diff options
Diffstat (limited to 'searchlib/src/tests/features/bm25/bm25_test.cpp')
-rw-r--r-- | searchlib/src/tests/features/bm25/bm25_test.cpp | 120 |
1 files changed, 107 insertions, 13 deletions
diff --git a/searchlib/src/tests/features/bm25/bm25_test.cpp b/searchlib/src/tests/features/bm25/bm25_test.cpp index 84bafcfa0ed..eb2f46650a6 100644 --- a/searchlib/src/tests/features/bm25/bm25_test.cpp +++ b/searchlib/src/tests/features/bm25/bm25_test.cpp @@ -11,6 +11,7 @@ using namespace search::features; using namespace search::fef; +using namespace search::fef::objectstore; using CollectionType = FieldInfo::CollectionType; using StringVector = std::vector<vespalib::string>; @@ -40,12 +41,13 @@ struct Bm25BlueprintTest : public ::testing::Test { EXPECT_FALSE(blueprint->setup(index_env, params)); } - void expect_setup_succeed(const StringVector& params) { + Blueprint::SP expect_setup_succeed(const StringVector& params) { auto blueprint = make_blueprint(); test::DummyDependencyHandler deps(*blueprint); EXPECT_TRUE(blueprint->setup(index_env, params)); EXPECT_EQ(0, deps.input.size()); EXPECT_EQ(StringVector({"score"}), deps.output); + return blueprint; } }; @@ -63,6 +65,18 @@ TEST_F(Bm25BlueprintTest, blueprint_setup_fails_when_parameter_list_is_not_valid expect_setup_fail({"is", "ia"}); // wrong parameter number } +TEST_F(Bm25BlueprintTest, blueprint_setup_fails_when_k1_param_is_malformed) +{ + index_env.getProperties().add("bm25(is).k1", "malformed"); + expect_setup_fail({"is"}); +} + +TEST_F(Bm25BlueprintTest, blueprint_setup_fails_when_b_param_is_malformed) +{ + index_env.getProperties().add("bm25(is).b", "malformed"); + expect_setup_fail({"is"}); +} + TEST_F(Bm25BlueprintTest, blueprint_setup_succeeds_for_index_field) { expect_setup_succeed({"is"}); @@ -70,11 +84,40 @@ TEST_F(Bm25BlueprintTest, blueprint_setup_succeeds_for_index_field) expect_setup_succeed({"iws"}); } +TEST_F(Bm25BlueprintTest, blueprint_can_prepare_shared_state_with_average_field_length) +{ + auto blueprint = expect_setup_succeed({"is"}); + test::QueryEnvironment query_env; + query_env.get_avg_field_lengths()["is"] = 10; + ObjectStore store; + blueprint->prepareSharedState(query_env, store); + EXPECT_DOUBLE_EQ(10, as_value<double>(*store.get("bm25.afl.is"))); +} + +struct Scorer { + + double avg_field_length; + double k1_param; + double b_param; + + Scorer() : + avg_field_length(10), + k1_param(1.2), + b_param(0.75) + {} + + feature_t score(feature_t num_occs, feature_t field_length, double inverse_doc_freq) const { + return inverse_doc_freq * (num_occs * (1 + k1_param)) / + (num_occs + (k1_param * ((1 - b_param) + b_param * field_length / avg_field_length))); + } +}; struct Bm25ExecutorTest : public ::testing::Test { BlueprintFactory factory; FtFeatureTest test; test::MatchDataBuilder::UP match_data; + Scorer scorer; + static constexpr uint32_t total_doc_count = 100; Bm25ExecutorTest() : factory(), @@ -84,19 +127,24 @@ struct Bm25ExecutorTest : public ::testing::Test { setup_search_features(factory); test.getIndexEnv().getBuilder().addField(FieldType::INDEX, CollectionType::SINGLE, "foo"); test.getIndexEnv().getBuilder().addField(FieldType::INDEX, CollectionType::SINGLE, "bar"); - test.getQueryEnv().getBuilder().addIndexNode({"foo"}); - test.getQueryEnv().getBuilder().addIndexNode({"foo"}); - test.getQueryEnv().getBuilder().addIndexNode({"bar"}); - + add_query_term("foo", 25); + add_query_term("foo", 35); + add_query_term("bar", 45); + test.getQueryEnv().getBuilder().set_avg_field_length("foo", 10); + } + void add_query_term(const vespalib::string& field_name, uint32_t matching_doc_count) { + auto* term = test.getQueryEnv().getBuilder().addIndexNode({field_name}); + term->field(0).setDocFreq(matching_doc_count, total_doc_count); + } + void setup() { EXPECT_TRUE(test.setup()); - match_data = test.createMatchDataBuilder(); clear_term(0, 0); clear_term(1, 0); clear_term(2, 1); } bool execute(feature_t exp_score) { - return test.execute(exp_score); + return test.execute(exp_score, 0.000001); } void clear_term(uint32_t term_id, uint32_t field_id) { auto* tfmd = match_data->getTermFieldMatchData(term_id, field_id); @@ -111,35 +159,81 @@ struct Bm25ExecutorTest : public ::testing::Test { tfmd->setFieldLength(field_length); } - feature_t get_score(feature_t num_occs, feature_t field_length) const { - return (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / 10.0))); + double idf(uint32_t matching_doc_count) const { + return Bm25Executor::calculate_inverse_document_frequency(matching_doc_count, total_doc_count); + } + + feature_t score(feature_t num_occs, feature_t field_length, double inverse_doc_freq) const { + return scorer.score(num_occs, field_length, inverse_doc_freq); } }; TEST_F(Bm25ExecutorTest, score_is_calculated_for_a_single_term) { + setup(); prepare_term(0, 0, 3, 20); - EXPECT_TRUE(execute(get_score(3.0, 20))); + EXPECT_TRUE(execute(score(3.0, 20, idf(25)))); } TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms) { + setup(); prepare_term(0, 0, 3, 20); prepare_term(1, 0, 7, 5); - EXPECT_TRUE(execute(get_score(3.0, 20) + get_score(7.0, 5.0))); + EXPECT_TRUE(execute(score(3.0, 20, idf(25)) + score(7.0, 5.0, idf(35)))); } TEST_F(Bm25ExecutorTest, term_that_does_not_match_document_is_ignored) { + setup(); prepare_term(0, 0, 3, 20); - prepare_term(1, 0, 7, 5, 123); - EXPECT_TRUE(execute(get_score(3.0, 20))); + uint32_t unmatched_doc_id = 123; + prepare_term(1, 0, 7, 5, unmatched_doc_id); + EXPECT_TRUE(execute(score(3.0, 20, idf(25)))); } TEST_F(Bm25ExecutorTest, term_searching_another_field_is_ignored) { + setup(); prepare_term(2, 1, 3, 20); EXPECT_TRUE(execute(0.0)); } +TEST_F(Bm25ExecutorTest, uses_average_field_length_from_shared_state_if_found) +{ + test.getQueryEnv().getObjectStore().add("bm25.afl.foo", std::make_unique<AnyWrapper<double>>(15)); + setup(); + prepare_term(0, 0, 3, 20); + scorer.avg_field_length = 15; + EXPECT_TRUE(execute(score(3.0, 20, idf(25)))); +} + +TEST_F(Bm25ExecutorTest, calculates_inverse_document_frequency) +{ + EXPECT_DOUBLE_EQ(std::log(1 + (99 + 0.5) / (1 + 0.5)), + Bm25Executor::calculate_inverse_document_frequency(1, 100)); + EXPECT_DOUBLE_EQ(std::log(1 + (60 + 0.5) / (40 + 0.5)), + Bm25Executor::calculate_inverse_document_frequency(40, 100)); + EXPECT_DOUBLE_EQ(std::log(1 + (0.5) / (100 + 0.5)), + Bm25Executor::calculate_inverse_document_frequency(100, 100)); +} + +TEST_F(Bm25ExecutorTest, k1_param_can_be_overriden) +{ + test.getIndexEnv().getProperties().add("bm25(foo).k1", "2.5"); + setup(); + prepare_term(0, 0, 3, 20); + scorer.k1_param = 2.5; + EXPECT_TRUE(execute(score(3.0, 20, idf(25)))); +} + +TEST_F(Bm25ExecutorTest, b_param_can_be_overriden) +{ + test.getIndexEnv().getProperties().add("bm25(foo).b", "0.9"); + setup(); + prepare_term(0, 0, 3, 20); + scorer.b_param = 0.9; + EXPECT_TRUE(execute(score(3.0, 20, idf(25)))); +} + GTEST_MAIN_RUN_ALL_TESTS() |