diff options
-rw-r--r-- | searchlib/CMakeLists.txt | 1 | ||||
-rw-r--r-- | searchlib/src/tests/features/bm25/CMakeLists.txt | 11 | ||||
-rw-r--r-- | searchlib/src/tests/features/bm25/bm25_test.cpp | 145 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/CMakeLists.txt | 1 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/bm25_feature.cpp | 101 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/bm25_feature.h | 59 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/setup.cpp | 16 |
7 files changed, 327 insertions, 7 deletions
diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt index 66408b1d7d7..10231fb4634 100644 --- a/searchlib/CMakeLists.txt +++ b/searchlib/CMakeLists.txt @@ -133,6 +133,7 @@ vespa_define_module( src/tests/expression/attributenode src/tests/features src/tests/features/beta + src/tests/features/bm25 src/tests/features/constant src/tests/features/element_completeness src/tests/features/element_similarity_feature diff --git a/searchlib/src/tests/features/bm25/CMakeLists.txt b/searchlib/src/tests/features/bm25/CMakeLists.txt new file mode 100644 index 00000000000..3f9b92684f8 --- /dev/null +++ b/searchlib/src/tests/features/bm25/CMakeLists.txt @@ -0,0 +1,11 @@ +# Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +vespa_add_executable(searchlib_features_bm25_test_app TEST + SOURCES + bm25_test.cpp + DEPENDS + searchlib + searchlib_test + gtest +) +vespa_add_test(NAME searchlib_features_bm25_test_app COMMAND searchlib_features_bm25_test_app) diff --git a/searchlib/src/tests/features/bm25/bm25_test.cpp b/searchlib/src/tests/features/bm25/bm25_test.cpp new file mode 100644 index 00000000000..84bafcfa0ed --- /dev/null +++ b/searchlib/src/tests/features/bm25/bm25_test.cpp @@ -0,0 +1,145 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+
+#include <vespa/searchlib/features/bm25_feature.h>
+#include <vespa/searchlib/features/setup.h>
+#include <vespa/searchlib/fef/blueprintfactory.h>
+#include <vespa/searchlib/fef/test/dummy_dependency_handler.h>
+#include <vespa/searchlib/fef/test/ftlib.h>
+#include <vespa/searchlib/fef/test/indexenvironment.h>
+#include <vespa/searchlib/fef/test/indexenvironmentbuilder.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace search::features;
+using namespace search::fef;
+using CollectionType = FieldInfo::CollectionType;
+using StringVector = std::vector<vespalib::string>;
+
+struct Bm25BlueprintTest : public ::testing::Test { // Fixture for blueprint creation and setup() parameter checks.
+    BlueprintFactory factory;
+    test::IndexEnvironment index_env;
+
+    Bm25BlueprintTest()
+        : factory(),
+          index_env()
+    {
+        setup_search_features(factory);
+        test::IndexEnvironmentBuilder builder(index_env);
+        builder.addField(FieldType::INDEX, CollectionType::SINGLE, "is"); // index fields of each collection type ...
+        builder.addField(FieldType::INDEX, CollectionType::ARRAY, "ia");
+        builder.addField(FieldType::INDEX, CollectionType::WEIGHTEDSET, "iws");
+        builder.addField(FieldType::ATTRIBUTE, CollectionType::SINGLE, "as"); // ... and one attribute, used as an invalid parameter
+    }
+
+    Blueprint::SP make_blueprint() const { // Asks the factory for a "bm25" blueprint.
+        return factory.createBlueprint("bm25");
+    }
+
+    void expect_setup_fail(const StringVector& params) { // Expects setup() to reject params.
+        auto blueprint = make_blueprint();
+        test::DummyDependencyHandler deps(*blueprint);
+        EXPECT_FALSE(blueprint->setup(index_env, params));
+    }
+
+    void expect_setup_succeed(const StringVector& params) { // Expects setup() to accept params and declare only the "score" output.
+        auto blueprint = make_blueprint();
+        test::DummyDependencyHandler deps(*blueprint);
+        EXPECT_TRUE(blueprint->setup(index_env, params));
+        EXPECT_EQ(0, deps.input.size()); // no inputs recorded by the dependency handler
+        EXPECT_EQ(StringVector({"score"}), deps.output);
+    }
+};
+
+TEST_F(Bm25BlueprintTest, blueprint_can_be_created_from_factory)
+{
+    auto bp = factory.createBlueprint("bm25");
+    EXPECT_TRUE(bp.get() != nullptr);
+    EXPECT_TRUE(dynamic_cast<Bm25Blueprint*>(bp.get()) != nullptr);
+}
+
+TEST_F(Bm25BlueprintTest, blueprint_setup_fails_when_parameter_list_is_not_valid)
+{
+    expect_setup_fail({}); // wrong parameter number
+    expect_setup_fail({"as"}); // 'as' is an attribute
+    expect_setup_fail({"is", "ia"}); // wrong parameter number
+}
+
+TEST_F(Bm25BlueprintTest, blueprint_setup_succeeds_for_index_field)
+{
+    expect_setup_succeed({"is"});
+    expect_setup_succeed({"ia"});
+    expect_setup_succeed({"iws"});
+}
+
+
+struct Bm25ExecutorTest : public ::testing::Test { // Fixture that executes bm25(foo) over hand-prepared term match data.
+    BlueprintFactory factory;
+    FtFeatureTest test;
+    test::MatchDataBuilder::UP match_data;
+
+    Bm25ExecutorTest()
+        : factory(),
+          test(factory, "bm25(foo)"),
+          match_data()
+    {
+        setup_search_features(factory);
+        test.getIndexEnv().getBuilder().addField(FieldType::INDEX, CollectionType::SINGLE, "foo");
+        test.getIndexEnv().getBuilder().addField(FieldType::INDEX, CollectionType::SINGLE, "bar");
+        test.getQueryEnv().getBuilder().addIndexNode({"foo"}); // addressed below as term 0 (ids assumed to follow insertion order)
+        test.getQueryEnv().getBuilder().addIndexNode({"foo"}); // addressed below as term 1
+        test.getQueryEnv().getBuilder().addIndexNode({"bar"}); // addressed below as term 2, searching the other field
+
+        EXPECT_TRUE(test.setup());
+
+        match_data = test.createMatchDataBuilder();
+        clear_term(0, 0); // every term starts out reset to doc id 123; tests opt terms in with prepare_term()
+        clear_term(1, 0);
+        clear_term(2, 1);
+    }
+    bool execute(feature_t exp_score) { // Forwards to FtFeatureTest::execute() with the expected score.
+        return test.execute(exp_score);
+    }
+    void clear_term(uint32_t term_id, uint32_t field_id) { // Resets the (term, field) match data to doc id 123.
+        auto* tfmd = match_data->getTermFieldMatchData(term_id, field_id);
+        ASSERT_TRUE(tfmd != nullptr);
+        tfmd->reset(123); // differs from prepare_term()'s default doc id of 1
+    }
+    void prepare_term(uint32_t term_id, uint32_t field_id, uint16_t num_occs, uint16_t field_length, uint32_t doc_id = 1) { // Sets occurrences and field length for doc_id.
+        auto* tfmd = match_data->getTermFieldMatchData(term_id, field_id);
+        ASSERT_TRUE(tfmd != nullptr);
+        tfmd->reset(doc_id);
+        tfmd->setNumOccs(num_occs);
+        tfmd->setFieldLength(field_length);
+    }
+
+    feature_t get_score(feature_t num_occs, feature_t field_length) const { // Expected single-term score; 2.2, 1.2, 0.25, 0.75 and 10.0 mirror the executor's hard coded k1 = 1.2, b = 0.75, average field length 10 and idf 1.
+        return (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / 10.0)));
+    }
+};
+
+TEST_F(Bm25ExecutorTest, score_is_calculated_for_a_single_term)
+{
+    prepare_term(0, 0, 3, 20);
+    EXPECT_TRUE(execute(get_score(3.0, 20)));
+}
+
+TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms)
+{
+    prepare_term(0, 0, 3, 20);
+    prepare_term(1, 0, 7, 5);
+    EXPECT_TRUE(execute(get_score(3.0, 20) + get_score(7.0, 5.0))); // contributions of both terms are summed
+}
+
+TEST_F(Bm25ExecutorTest, term_that_does_not_match_document_is_ignored)
+{
+    prepare_term(0, 0, 3, 20);
+    prepare_term(1, 0, 7, 5, 123); // bound to doc id 123 instead of the default 1
+    EXPECT_TRUE(execute(get_score(3.0, 20)));
+}
+
+TEST_F(Bm25ExecutorTest, term_searching_another_field_is_ignored)
+{
+    prepare_term(2, 1, 3, 20); // term 2 searches field 1, not the field given to bm25(foo)
+    EXPECT_TRUE(execute(0.0));
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/features/CMakeLists.txt b/searchlib/src/vespa/searchlib/features/CMakeLists.txt
index 16401a67424..727ace182eb 100644
--- a/searchlib/src/vespa/searchlib/features/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/features/CMakeLists.txt
@@ -5,6 +5,7 @@ vespa_add_library(searchlib_features OBJECT
     array_parser.cpp
     attributefeature.cpp
     attributematchfeature.cpp
+    bm25_feature.cpp
     closenessfeature.cpp
     constant_feature.cpp
     debug_attribute_wait.cpp
diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
new file mode 100644
index 00000000000..58c45afba9c
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
@@ -0,0 +1,101 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "bm25_feature.h"
+#include <vespa/searchlib/fef/itermdata.h>
+#include <vespa/searchlib/fef/itermfielddata.h>
+#include <memory>
+
+namespace search::features {
+
+using fef::Blueprint;
+using fef::FeatureExecutor;
+using fef::FieldInfo;
+using fef::ITermData;
+using fef::ITermFieldData;
+
+Bm25Executor::Bm25Executor(const fef::FieldInfo& field,
+                           const fef::IQueryEnvironment& env)
+    : FeatureExecutor(),
+      _terms(),
+      _avg_field_length(10),
+      _k1_param(1.2),
+      _b_param(0.75)
+{
+    // TODO: Don't use hard coded avg_field_length
+    // TODO: Add support for setting k1 and b
+    for (size_t i = 0; i < env.getNumTerms(); ++i) {
+        const ITermData* term = env.getTerm(i);
+        for (size_t j = 0; j < term->numFields(); ++j) {
+            const ITermFieldData& term_field = term->field(j);
+            if (field.id() == term_field.getFieldId()) { // Keep only the terms searching this field.
+                // TODO: Add proper calculation of IDF
+                _terms.emplace_back(term_field.getHandle(), 1.0);
+            }
+        }
+    }
+}
+
+void
+Bm25Executor::handle_bind_match_data(const fef::MatchData& match_data)
+{
+    for (auto& term : _terms) {
+        term.tfmd = match_data.resolveTermField(term.handle);
+    }
+}
+
+void
+Bm25Executor::execute(uint32_t doc_id)
+{
+    feature_t score = 0;
+    for (const auto& term : _terms) {
+        if (term.tfmd->getDocId() == doc_id) { // Terms whose match data is for another document contribute nothing.
+            const feature_t num_occs = term.tfmd->getNumOccs();
+            const feature_t norm_field_length = static_cast<feature_t>(term.tfmd->getFieldLength()) / _avg_field_length;
+
+            const feature_t numerator = term.inverse_doc_freq * num_occs * (_k1_param + 1);
+            const feature_t denominator = num_occs + (_k1_param * (1 - _b_param + (_b_param * norm_field_length)));
+
+            score += numerator / denominator;
+        }
+    }
+    outputs().set_number(0, score);
+}
+
+
+Bm25Blueprint::Bm25Blueprint()
+    : Blueprint("bm25"),
+      _field(nullptr)
+{
+}
+
+void
+Bm25Blueprint::visitDumpFeatures(const fef::IIndexEnvironment& env, fef::IDumpFeatureVisitor& visitor) const
+{
+    (void) env;
+    (void) visitor;
+    // TODO: Implement
+}
+
+fef::Blueprint::UP
+Bm25Blueprint::createInstance() const
+{
+    return std::make_unique<Bm25Blueprint>();
+}
+
+bool
+Bm25Blueprint::setup(const fef::IIndexEnvironment& env, const fef::ParameterList& params)
+{
+    const auto& field_name = params[0].getValue();
+    _field = env.getFieldByName(field_name);
+
+    describeOutput("score", "The bm25 score for all terms searching in the given index field");
+    return (_field != nullptr); // Setup fails when the named field is unknown to the index environment.
+}
+
+fef::FeatureExecutor&
+Bm25Blueprint::createExecutor(const fef::IQueryEnvironment& env, vespalib::Stash& stash) const
+{
+    return stash.create<Bm25Executor>(*_field, env);
+}
+
+}
diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.h b/searchlib/src/vespa/searchlib/features/bm25_feature.h
new file mode 100644
index 00000000000..457cfea4c87
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/features/bm25_feature.h
@@ -0,0 +1,71 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/searchlib/fef/blueprint.h>
+#include <vespa/searchlib/fef/featureexecutor.h>
+#include <vector>
+
+namespace search::features {
+
+/**
+ * Executor for the BM25 ranking algorithm over a single index field.
+ *
+ * The score is the sum, over the query terms that search the field and whose
+ * match data is for the document being scored, of:
+ *   idf * occs * (k1 + 1) / (occs + k1 * (1 - b + b * field_length / avg_field_length))
+ */
+class Bm25Executor : public fef::FeatureExecutor {
+private:
+    // A query term searching the field, together with the match data it is bound to.
+    struct QueryTerm {
+        fef::TermFieldHandle handle;
+        const fef::TermFieldMatchData* tfmd; // nullptr until handle_bind_match_data() has run.
+        double inverse_doc_freq;
+        QueryTerm(fef::TermFieldHandle handle_, double inverse_doc_freq_)
+            : handle(handle_),
+              tfmd(nullptr),
+              inverse_doc_freq(inverse_doc_freq_)
+        {}
+    };
+
+    using QueryTermVector = std::vector<QueryTerm>;
+
+    QueryTermVector _terms;
+    double _avg_field_length;
+    double _k1_param; // Determines term frequency saturation characteristics.
+    double _b_param; // Adjusts the effects of the field length of the document matched compared to the average field length.
+
+public:
+    // Collects the query terms in env that search the given field.
+    Bm25Executor(const fef::FieldInfo& field,
+                 const fef::IQueryEnvironment& env);
+
+    // Resolves the match data of each collected term from its handle.
+    void handle_bind_match_data(const fef::MatchData& match_data) override;
+    // Sums the contribution of each term whose match data is for doc_id and writes the sum to output 0.
+    void execute(uint32_t doc_id) override;
+};
+
+
+/**
+ * Blueprint for the BM25 ranking algorithm over a single index field.
+ */
+class Bm25Blueprint : public fef::Blueprint {
+private:
+    const fef::FieldInfo* _field; // Set by setup(); nullptr before that or when the field is unknown.
+
+public:
+    Bm25Blueprint();
+
+    void visitDumpFeatures(const fef::IIndexEnvironment& env, fef::IDumpFeatureVisitor& visitor) const override;
+    fef::Blueprint::UP createInstance() const override;
+    // Declares a single parameter: an index field of any collection type.
+    fef::ParameterDescriptions getDescriptions() const override {
+        return fef::ParameterDescriptions().desc().indexField(fef::ParameterCollection::ANY);
+    }
+    bool setup(const fef::IIndexEnvironment& env, const fef::ParameterList& params) override;
+    fef::FeatureExecutor& createExecutor(const fef::IQueryEnvironment& env, vespalib::Stash& stash) const override;
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/features/setup.cpp b/searchlib/src/vespa/searchlib/features/setup.cpp
index a4a9689b971..e1e351d0726 100644
--- a/searchlib/src/vespa/searchlib/features/setup.cpp
+++ b/searchlib/src/vespa/searchlib/features/setup.cpp
@@ -4,7 +4,9 @@
 #include "agefeature.h"
 #include "attributefeature.h"
 #include "attributematchfeature.h"
+#include "bm25_feature.h"
 #include "closenessfeature.h"
+#include "constant_feature.h"
 #include "debug_attribute_wait.h"
 #include "debug_wait.h"
 #include "distancefeature.h"
@@ -36,9 +38,9 @@
 #include "querycompletenessfeature.h"
 #include "queryfeature.h"
 #include "querytermcountfeature.h"
-#include "randomfeature.h"
 #include "random_normal_feature.h"
 #include "random_normal_stable_feature.h"
+#include "randomfeature.h"
 #include "rankingexpressionfeature.h"
 #include "raw_score_feature.h"
 #include "reverseproximityfeature.h"
@@ -52,7 +54,6 @@
 #include "terminfofeature.h"
 #include "text_similarity_feature.h"
 #include "valuefeature.h"
-#include "constant_feature.h"
#include "max_reduce_prod_join_replacer.h" #include <vespa/searchlib/features/rankingexpression/expression_replacer.h> @@ -69,27 +70,28 @@ void setup_search_features(fef::IBlueprintRegistry & registry) registry.addPrototype(Blueprint::SP(new AgeBlueprint())); registry.addPrototype(Blueprint::SP(new AttributeBlueprint())); registry.addPrototype(Blueprint::SP(new AttributeMatchBlueprint())); + registry.addPrototype(Blueprint::SP(new Bm25Blueprint())); registry.addPrototype(Blueprint::SP(new ClosenessBlueprint())); - registry.addPrototype(Blueprint::SP(new MatchCountBlueprint())); - registry.addPrototype(Blueprint::SP(new DistanceBlueprint())); - registry.addPrototype(Blueprint::SP(new DistanceToPathBlueprint())); registry.addPrototype(Blueprint::SP(new DebugAttributeWaitBlueprint())); registry.addPrototype(Blueprint::SP(new DebugWaitBlueprint())); + registry.addPrototype(Blueprint::SP(new DistanceBlueprint())); + registry.addPrototype(Blueprint::SP(new DistanceToPathBlueprint())); registry.addPrototype(Blueprint::SP(new DotProductBlueprint())); registry.addPrototype(Blueprint::SP(new ElementCompletenessBlueprint())); registry.addPrototype(Blueprint::SP(new ElementSimilarityBlueprint())); registry.addPrototype(Blueprint::SP(new EuclideanDistanceBlueprint())); registry.addPrototype(Blueprint::SP(new FieldInfoBlueprint())); - registry.addPrototype(Blueprint::SP(new FlowCompletenessBlueprint())); registry.addPrototype(Blueprint::SP(new FieldLengthBlueprint())); registry.addPrototype(Blueprint::SP(new FieldMatchBlueprint())); registry.addPrototype(Blueprint::SP(new FieldTermMatchBlueprint())); registry.addPrototype(Blueprint::SP(new FirstPhaseBlueprint())); + registry.addPrototype(Blueprint::SP(new FlowCompletenessBlueprint())); registry.addPrototype(Blueprint::SP(new ForeachBlueprint())); registry.addPrototype(Blueprint::SP(new FreshnessBlueprint())); registry.addPrototype(Blueprint::SP(new ItemRawScoreBlueprint())); - registry.addPrototype(Blueprint::SP(new 
MatchesBlueprint())); registry.addPrototype(Blueprint::SP(new MatchBlueprint())); + registry.addPrototype(Blueprint::SP(new MatchCountBlueprint())); + registry.addPrototype(Blueprint::SP(new MatchesBlueprint())); registry.addPrototype(Blueprint::SP(new NativeAttributeMatchBlueprint())); registry.addPrototype(Blueprint::SP(new NativeDotProductBlueprint())); registry.addPrototype(Blueprint::SP(new NativeFieldMatchBlueprint())); |