summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-06-12 16:17:34 +0200
committerGitHub <noreply@github.com>2019-06-12 16:17:34 +0200
commit1f06d0d5f6218dafc1a0a0dc1dae85e86cb6ac90 (patch)
treee93c5dc8a8581aa17e48e5a0abecbab5f13b6ea5 /searchlib
parentd3e5bf770cf893f7a9196fa49b07eae651d3c688 (diff)
parent3704383893ba63e3ea674881f8cd24da6d7683d8 (diff)
Merge pull request #9763 from vespa-engine/geirst/idf-calculation-in-bm25-feature
Geirst/idf calculation in bm25 feature
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/features/bm25/bm25_test.cpp43
-rw-r--r--searchlib/src/tests/features/prod_features.cpp6
-rw-r--r--searchlib/src/tests/fef/termfieldmodel/termfieldmodel_test.cpp2
-rw-r--r--searchlib/src/vespa/searchlib/features/bm25_feature.cpp14
-rw-r--r--searchlib/src/vespa/searchlib/features/bm25_feature.h2
-rw-r--r--searchlib/src/vespa/searchlib/fef/itermfielddata.h17
-rw-r--r--searchlib/src/vespa/searchlib/fef/simpletermfielddata.cpp12
-rw-r--r--searchlib/src/vespa/searchlib/fef/simpletermfielddata.h40
8 files changed, 82 insertions, 54 deletions
diff --git a/searchlib/src/tests/features/bm25/bm25_test.cpp b/searchlib/src/tests/features/bm25/bm25_test.cpp
index 1a3895d7e28..6f0d4b80051 100644
--- a/searchlib/src/tests/features/bm25/bm25_test.cpp
+++ b/searchlib/src/tests/features/bm25/bm25_test.cpp
@@ -86,6 +86,7 @@ struct Bm25ExecutorTest : public ::testing::Test {
BlueprintFactory factory;
FtFeatureTest test;
test::MatchDataBuilder::UP match_data;
+ static constexpr uint32_t total_doc_count = 100;
Bm25ExecutorTest()
: factory(),
@@ -95,11 +96,15 @@ struct Bm25ExecutorTest : public ::testing::Test {
setup_search_features(factory);
test.getIndexEnv().getBuilder().addField(FieldType::INDEX, CollectionType::SINGLE, "foo");
test.getIndexEnv().getBuilder().addField(FieldType::INDEX, CollectionType::SINGLE, "bar");
- test.getQueryEnv().getBuilder().addIndexNode({"foo"});
- test.getQueryEnv().getBuilder().addIndexNode({"foo"});
- test.getQueryEnv().getBuilder().addIndexNode({"bar"});
+ add_query_term("foo", 25);
+ add_query_term("foo", 35);
+ add_query_term("bar", 45);
test.getQueryEnv().getBuilder().set_avg_field_length("foo", 10);
}
+ void add_query_term(const vespalib::string& field_name, uint32_t matching_doc_count) {
+ auto* term = test.getQueryEnv().getBuilder().addIndexNode({field_name});
+ term->field(0).setDocFreq(matching_doc_count, total_doc_count);
+ }
void setup() {
EXPECT_TRUE(test.setup());
match_data = test.createMatchDataBuilder();
@@ -108,7 +113,7 @@ struct Bm25ExecutorTest : public ::testing::Test {
clear_term(2, 1);
}
bool execute(feature_t exp_score) {
- return test.execute(exp_score);
+ return test.execute(exp_score, 0.000001);
}
void clear_term(uint32_t term_id, uint32_t field_id) {
auto* tfmd = match_data->getTermFieldMatchData(term_id, field_id);
@@ -123,8 +128,13 @@ struct Bm25ExecutorTest : public ::testing::Test {
tfmd->setFieldLength(field_length);
}
- feature_t get_score(feature_t num_occs, feature_t field_length, double avg_field_length = 10) const {
- return (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / avg_field_length)));
+ double idf(uint32_t matching_doc_count) const {
+ return Bm25Executor::calculate_inverse_document_frequency(matching_doc_count, total_doc_count);
+ }
+
+ feature_t get_score(feature_t num_occs, feature_t field_length,
+ double inverse_doc_freq, double avg_field_length = 10) const {
+ return inverse_doc_freq * (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / avg_field_length)));
}
};
@@ -132,7 +142,7 @@ TEST_F(Bm25ExecutorTest, score_is_calculated_for_a_single_term)
{
setup();
prepare_term(0, 0, 3, 20);
- EXPECT_TRUE(execute(get_score(3.0, 20)));
+ EXPECT_TRUE(execute(get_score(3.0, 20, idf(25))));
}
TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms)
@@ -140,15 +150,16 @@ TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms)
setup();
prepare_term(0, 0, 3, 20);
prepare_term(1, 0, 7, 5);
- EXPECT_TRUE(execute(get_score(3.0, 20) + get_score(7.0, 5.0)));
+ EXPECT_TRUE(execute(get_score(3.0, 20, idf(25)) + get_score(7.0, 5.0, idf(35))));
}
TEST_F(Bm25ExecutorTest, term_that_does_not_match_document_is_ignored)
{
setup();
prepare_term(0, 0, 3, 20);
- prepare_term(1, 0, 7, 5, 123);
- EXPECT_TRUE(execute(get_score(3.0, 20)));
+ uint32_t unmatched_doc_id = 123;
+ prepare_term(1, 0, 7, 5, unmatched_doc_id);
+ EXPECT_TRUE(execute(get_score(3.0, 20, idf(25))));
}
TEST_F(Bm25ExecutorTest, term_searching_another_field_is_ignored)
@@ -163,7 +174,17 @@ TEST_F(Bm25ExecutorTest, uses_average_field_length_from_shared_state_if_found)
test.getQueryEnv().getObjectStore().add("bm25.afl.foo", std::make_unique<AnyWrapper<double>>(15));
setup();
prepare_term(0, 0, 3, 20);
- EXPECT_TRUE(execute(get_score(3.0, 20, 15)));
+ EXPECT_TRUE(execute(get_score(3.0, 20, idf(25), 15)));
+}
+
+TEST_F(Bm25ExecutorTest, calculates_inverse_document_frequency)
+{
+ EXPECT_DOUBLE_EQ(std::log(1 + (99 + 0.5) / (1 + 0.5)),
+ Bm25Executor::calculate_inverse_document_frequency(1, 100));
+ EXPECT_DOUBLE_EQ(std::log(1 + (60 + 0.5) / (40 + 0.5)),
+ Bm25Executor::calculate_inverse_document_frequency(40, 100));
+ EXPECT_DOUBLE_EQ(std::log(1 + (0.5) / (100 + 0.5)),
+ Bm25Executor::calculate_inverse_document_frequency(100, 100));
}
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/tests/features/prod_features.cpp b/searchlib/src/tests/features/prod_features.cpp
index 626a470cb5c..70250b05bf1 100644
--- a/searchlib/src/tests/features/prod_features.cpp
+++ b/searchlib/src/tests/features/prod_features.cpp
@@ -1968,8 +1968,10 @@ Test::testTerm()
.addField(FieldType::INDEX, CollectionType::SINGLE, "idx2") // field 1
.addField(FieldType::ATTRIBUTE, CollectionType::SINGLE, "attr"); // field 2
ft.getQueryEnv().getBuilder().addAllFields().setUniqueId(0);
- ft.getQueryEnv().getBuilder().addAllFields().setUniqueId(1).setWeight(search::query::Weight(200)).lookupField(0)->setDocFreq(0.5);
- ft.getQueryEnv().getBuilder().addAttributeNode("attr")->setUniqueId(2).setWeight(search::query::Weight(400)).lookupField(2)->setDocFreq(0.25);
+ ft.getQueryEnv().getBuilder().addAllFields().setUniqueId(1)
+ .setWeight(search::query::Weight(200)).lookupField(0)->setDocFreq(50, 100);
+ ft.getQueryEnv().getBuilder().addAttributeNode("attr")->setUniqueId(2)
+ .setWeight(search::query::Weight(400)).lookupField(2)->setDocFreq(25, 100);
// setup connectedness between term 1 and term 0
ft.getQueryEnv().getProperties().add("vespa.term.1.connexity", "0");
ft.getQueryEnv().getProperties().add("vespa.term.1.connexity", "0.7");
diff --git a/searchlib/src/tests/fef/termfieldmodel/termfieldmodel_test.cpp b/searchlib/src/tests/fef/termfieldmodel/termfieldmodel_test.cpp
index 9ed94c02287..3a0c334fbba 100644
--- a/searchlib/src/tests/fef/termfieldmodel/termfieldmodel_test.cpp
+++ b/searchlib/src/tests/fef/termfieldmodel/termfieldmodel_test.cpp
@@ -50,7 +50,7 @@ void testSetup(State &state) {
{
int i = 1;
for (SFR iter(state.term); iter.valid(); iter.next()) {
- iter.get().setDocFreq(0.25 * i++);
+ iter.get().setDocFreq(25 * i++, 100);
}
}
diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
index 0be3c2876f7..5a9e8455d73 100644
--- a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
@@ -4,6 +4,7 @@
#include <vespa/searchlib/fef/itermdata.h>
#include <vespa/searchlib/fef/itermfielddata.h>
#include <vespa/searchlib/fef/objectstore.h>
+#include <cmath>
#include <memory>
namespace search::features {
@@ -32,13 +33,22 @@ Bm25Executor::Bm25Executor(const fef::FieldInfo& field,
for (size_t j = 0; j < term->numFields(); ++j) {
const ITermFieldData& term_field = term->field(j);
if (field.id() == term_field.getFieldId()) {
- // TODO: Add proper calculation of IDF
- _terms.emplace_back(term_field.getHandle(MatchDataDetails::Cheap), 1.0);
+ // TODO: Add support for using significance instead of default idf if specified in the query
+ _terms.emplace_back(term_field.getHandle(MatchDataDetails::Cheap),
+ calculate_inverse_document_frequency(term_field.get_matching_doc_count(),
+ term_field.get_total_doc_count()));
}
}
}
}
+double
+Bm25Executor::calculate_inverse_document_frequency(uint32_t matching_doc_count, uint32_t total_doc_count)
+{
+ return std::log(1 + (static_cast<double>(total_doc_count - matching_doc_count + 0.5) /
+ static_cast<double>(matching_doc_count + 0.5)));
+}
+
void
Bm25Executor::handle_bind_match_data(const fef::MatchData& match_data)
{
diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.h b/searchlib/src/vespa/searchlib/features/bm25_feature.h
index 4b1576b57b3..4b5ea0214bc 100644
--- a/searchlib/src/vespa/searchlib/features/bm25_feature.h
+++ b/searchlib/src/vespa/searchlib/features/bm25_feature.h
@@ -33,6 +33,8 @@ public:
const fef::IQueryEnvironment& env,
double avg_field_length);
+ double static calculate_inverse_document_frequency(uint32_t matching_doc_count, uint32_t total_doc_count);
+
void handle_bind_match_data(const fef::MatchData& match_data) override;
void execute(uint32_t docId) override;
};
diff --git a/searchlib/src/vespa/searchlib/fef/itermfielddata.h b/searchlib/src/vespa/searchlib/fef/itermfielddata.h
index 80343db2250..6fb467ce25c 100644
--- a/searchlib/src/vespa/searchlib/fef/itermfielddata.h
+++ b/searchlib/src/vespa/searchlib/fef/itermfielddata.h
@@ -27,13 +27,26 @@ public:
**/
virtual uint32_t getFieldId() const = 0;
+
+ /**
+ * Returns the number of documents matching this term.
+ */
+ virtual uint32_t get_matching_doc_count() const = 0;
+
+ /**
+ * Returns the total number of documents in the corpus.
+ */
+ virtual uint32_t get_total_doc_count() const = 0;
+
/**
* Obtain the document frequency. This is a value between 0 and 1
* indicating the ratio of the matching documents to the corpus.
*
* @return document frequency
- **/
- virtual double getDocFreq() const = 0;
+ **/
+ double getDocFreq() const {
+ return (double)get_matching_doc_count() / (double)get_total_doc_count();
+ }
/**
* Obtain the match handle for this field,
diff --git a/searchlib/src/vespa/searchlib/fef/simpletermfielddata.cpp b/searchlib/src/vespa/searchlib/fef/simpletermfielddata.cpp
index d1edee7fd07..64906eed22e 100644
--- a/searchlib/src/vespa/searchlib/fef/simpletermfielddata.cpp
+++ b/searchlib/src/vespa/searchlib/fef/simpletermfielddata.cpp
@@ -2,22 +2,22 @@
#include "simpletermfielddata.h"
-namespace search {
-namespace fef {
+namespace search::fef {
SimpleTermFieldData::SimpleTermFieldData(uint32_t fieldId)
: _fieldId(fieldId),
- _docFreq(0),
+ _matching_doc_count(0),
+ _total_doc_count(1),
_handle(IllegalHandle)
{
}
SimpleTermFieldData::SimpleTermFieldData(const ITermFieldData &rhs)
: _fieldId(rhs.getFieldId()),
- _docFreq(rhs.getDocFreq()),
+ _matching_doc_count(rhs.get_matching_doc_count()),
+ _total_doc_count(rhs.get_total_doc_count()),
_handle(rhs.getHandle())
{
}
-} // namespace fef
-} // namespace search
+}
diff --git a/searchlib/src/vespa/searchlib/fef/simpletermfielddata.h b/searchlib/src/vespa/searchlib/fef/simpletermfielddata.h
index 6f0fbc9af64..d92d3a48f03 100644
--- a/searchlib/src/vespa/searchlib/fef/simpletermfielddata.h
+++ b/searchlib/src/vespa/searchlib/fef/simpletermfielddata.h
@@ -4,8 +4,7 @@
#include "itermfielddata.h"
-namespace search {
-namespace fef {
+namespace search::fef {
/**
* Information about a single field that is being searched for a term
@@ -17,7 +16,8 @@ class SimpleTermFieldData : public ITermFieldData
{
private:
uint32_t _fieldId;
- double _docFreq;
+ uint32_t _matching_doc_count;
+ uint32_t _total_doc_count;
TermFieldHandle _handle;
public:
@@ -33,28 +33,14 @@ public:
**/
SimpleTermFieldData(uint32_t fieldId);
- /**
- * Obtain the field id.
- *
- * @return field id
- **/
uint32_t getFieldId() const override final { return _fieldId; }
- /**
- * Obtain the document frequency.
- *
- * @return document frequency
- **/
- double getDocFreq() const override final { return _docFreq; }
+ uint32_t get_matching_doc_count() const override { return _matching_doc_count; }
+
+ uint32_t get_total_doc_count() const override { return _total_doc_count; }
using ITermFieldData::getHandle;
- /**
- * Obtain the match handle for this field,
- * requesting match data with the given details in the corresponding TermFieldMatchData.
- *
- * @return match handle (or IllegalHandle)
- **/
TermFieldHandle getHandle(MatchDataDetails requestedDetails) const override {
(void) requestedDetails;
return _handle;
@@ -62,20 +48,15 @@ public:
/**
* Sets the document frequency.
- *
- * @return this object (for chaining)
- * @param docFreq document frequency
**/
- SimpleTermFieldData &setDocFreq(double docFreq) {
- _docFreq = docFreq;
+ SimpleTermFieldData &setDocFreq(uint32_t matching_doc_count, uint32_t total_doc_count) {
+ _matching_doc_count = matching_doc_count;
+ _total_doc_count = total_doc_count;
return *this;
}
/**
* Sets the match handle for this field.
- *
- * @return this object (for chaining)
- * @param handle match handle
**/
SimpleTermFieldData &setHandle(TermFieldHandle handle) {
_handle = handle;
@@ -83,6 +64,5 @@ public:
}
};
-} // namespace fef
-} // namespace search
+}