diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-06-11 16:22:23 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-11 16:22:23 +0200 |
commit | 148d2dec420f6aca2278c1c29d3f46e6e680b746 (patch) | |
tree | 5bb9b986f607505feec9530c70f999c07f6ec742 | |
parent | e4e999fb5a70cad3f9f869ab0c915884f88e52c0 (diff) | |
parent | db04e49b9dcb3b671c351cd8cb4689f9d7217f01 (diff) |
Merge pull request #9751 from vespa-engine/geirst/average-field-length-in-bm25-feature
Use average field length (as provided by query environment) in bm25 r…
8 files changed, 120 insertions, 35 deletions
diff --git a/searchlib/src/tests/features/bm25/bm25_test.cpp b/searchlib/src/tests/features/bm25/bm25_test.cpp index 84bafcfa0ed..1a3895d7e28 100644 --- a/searchlib/src/tests/features/bm25/bm25_test.cpp +++ b/searchlib/src/tests/features/bm25/bm25_test.cpp @@ -11,6 +11,7 @@ using namespace search::features; using namespace search::fef; +using namespace search::fef::objectstore; using CollectionType = FieldInfo::CollectionType; using StringVector = std::vector<vespalib::string>; @@ -40,12 +41,13 @@ struct Bm25BlueprintTest : public ::testing::Test { EXPECT_FALSE(blueprint->setup(index_env, params)); } - void expect_setup_succeed(const StringVector& params) { + Blueprint::SP expect_setup_succeed(const StringVector& params) { auto blueprint = make_blueprint(); test::DummyDependencyHandler deps(*blueprint); EXPECT_TRUE(blueprint->setup(index_env, params)); EXPECT_EQ(0, deps.input.size()); EXPECT_EQ(StringVector({"score"}), deps.output); + return blueprint; } }; @@ -70,6 +72,15 @@ TEST_F(Bm25BlueprintTest, blueprint_setup_succeeds_for_index_field) expect_setup_succeed({"iws"}); } +TEST_F(Bm25BlueprintTest, blueprint_can_prepare_shared_state_with_average_field_length) +{ + auto blueprint = expect_setup_succeed({"is"}); + test::QueryEnvironment query_env; + query_env.get_avg_field_lengths()["is"] = 10; + ObjectStore store; + blueprint->prepareSharedState(query_env, store); + EXPECT_DOUBLE_EQ(10, as_value<double>(*store.get("bm25.afl.is"))); +} struct Bm25ExecutorTest : public ::testing::Test { BlueprintFactory factory; @@ -87,9 +98,10 @@ struct Bm25ExecutorTest : public ::testing::Test { test.getQueryEnv().getBuilder().addIndexNode({"foo"}); test.getQueryEnv().getBuilder().addIndexNode({"foo"}); test.getQueryEnv().getBuilder().addIndexNode({"bar"}); - + test.getQueryEnv().getBuilder().set_avg_field_length("foo", 10); + } + void setup() { EXPECT_TRUE(test.setup()); - match_data = test.createMatchDataBuilder(); clear_term(0, 0); clear_term(1, 0); @@ -111,19 +123,21 @@ struct Bm25ExecutorTest : public ::testing::Test { tfmd->setFieldLength(field_length); } - feature_t get_score(feature_t num_occs, feature_t field_length) const { - return (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / 10.0))); + feature_t get_score(feature_t num_occs, feature_t field_length, double avg_field_length = 10) const { + return (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / avg_field_length))); } }; TEST_F(Bm25ExecutorTest, score_is_calculated_for_a_single_term) { + setup(); prepare_term(0, 0, 3, 20); EXPECT_TRUE(execute(get_score(3.0, 20))); } TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms) { + setup(); prepare_term(0, 0, 3, 20); prepare_term(1, 0, 7, 5); EXPECT_TRUE(execute(get_score(3.0, 20) + get_score(7.0, 5.0))); @@ -131,6 +145,7 @@ TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms) TEST_F(Bm25ExecutorTest, term_that_does_not_match_document_is_ignored) { + setup(); prepare_term(0, 0, 3, 20); prepare_term(1, 0, 7, 5, 123); EXPECT_TRUE(execute(get_score(3.0, 20))); @@ -138,8 +153,17 @@ TEST_F(Bm25ExecutorTest, term_that_does_not_match_document_is_ignored) TEST_F(Bm25ExecutorTest, term_searching_another_field_is_ignored) { + setup(); prepare_term(2, 1, 3, 20); EXPECT_TRUE(execute(0.0)); } +TEST_F(Bm25ExecutorTest, uses_average_field_length_from_shared_state_if_found) +{ + test.getQueryEnv().getObjectStore().add("bm25.afl.foo", std::make_unique<AnyWrapper<double>>(15)); + setup(); + prepare_term(0, 0, 3, 20); + EXPECT_TRUE(execute(get_score(3.0, 20, 15))); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp index a9430db09c3..0be3c2876f7 100644 --- a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp @@ -3,26 +3,29 @@ #include "bm25_feature.h" #include <vespa/searchlib/fef/itermdata.h> #include <vespa/searchlib/fef/itermfielddata.h> +#include <vespa/searchlib/fef/objectstore.h> #include <memory> namespace search::features { +using fef::AnyWrapper; using fef::Blueprint; using fef::FeatureExecutor; using fef::FieldInfo; using fef::ITermData; using fef::ITermFieldData; using fef::MatchDataDetails; +using fef::objectstore::as_value; Bm25Executor::Bm25Executor(const fef::FieldInfo& field, - const fef::IQueryEnvironment& env) + const fef::IQueryEnvironment& env, + double avg_field_length) : FeatureExecutor(), _terms(), - _avg_field_length(10), + _avg_field_length(avg_field_length), _k1_param(1.2), _b_param(0.75) { - // TODO: Don't use hard coded avg_field_length // TODO: Add support for setting k1 and b for (size_t i = 0; i < env.getNumTerms(); ++i) { const ITermData* term = env.getTerm(i); @@ -93,10 +96,33 @@ Bm25Blueprint::setup(const fef::IIndexEnvironment& env, const fef::ParameterList return (_field != nullptr); } +namespace { + +vespalib::string +make_avg_field_length_key(const vespalib::string& base_name, const vespalib::string& field_name) +{ + return base_name + ".afl." + field_name; +} + +} + +void +Bm25Blueprint::prepareSharedState(const fef::IQueryEnvironment& env, fef::IObjectStore& store) const +{ + vespalib::string key = make_avg_field_length_key(getBaseName(), _field->name()); + if (store.get(key) == nullptr) { + store.add(key, std::make_unique<AnyWrapper<double>>(env.get_average_field_length(_field->name()))); + } +} + fef::FeatureExecutor& Bm25Blueprint::createExecutor(const fef::IQueryEnvironment& env, vespalib::Stash& stash) const { - return stash.create<Bm25Executor>(*_field, env); + const auto* lookup_result = env.getObjectStore().get(make_avg_field_length_key(getBaseName(), _field->name())); + double avg_field_length = lookup_result != nullptr ? + as_value<double>(*lookup_result) : + env.get_average_field_length(_field->name()); + return stash.create<Bm25Executor>(*_field, env, avg_field_length); } } diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.h b/searchlib/src/vespa/searchlib/features/bm25_feature.h index 457cfea4c87..4b1576b57b3 100644 --- a/searchlib/src/vespa/searchlib/features/bm25_feature.h +++ b/searchlib/src/vespa/searchlib/features/bm25_feature.h @@ -30,7 +30,8 @@ private: public: Bm25Executor(const fef::FieldInfo& field, - const fef::IQueryEnvironment& env); + const fef::IQueryEnvironment& env, + double avg_field_length); void handle_bind_match_data(const fef::MatchData& match_data) override; void execute(uint32_t docId) override; @@ -53,6 +54,7 @@ public: return fef::ParameterDescriptions().desc().indexField(fef::ParameterCollection::ANY); } bool setup(const fef::IIndexEnvironment& env, const fef::ParameterList& params) override; + void prepareSharedState(const fef::IQueryEnvironment& env, fef::IObjectStore& store) const override; fef::FeatureExecutor& createExecutor(const fef::IQueryEnvironment& env, vespalib::Stash& stash) const override; }; diff --git a/searchlib/src/vespa/searchlib/fef/objectstore.h b/searchlib/src/vespa/searchlib/fef/objectstore.h index 49176afa3c9..2debcd277e9 100644 --- a/searchlib/src/vespa/searchlib/fef/objectstore.h +++ b/searchlib/src/vespa/searchlib/fef/objectstore.h @@ -2,9 +2,13 @@ #pragma once #include <vespa/vespalib/stllike/hash_map.h> +#include <cassert> namespace search::fef { +/** + * Top level interface for things to store in an IObjectStore. + */ class Anything { public: @@ -12,6 +16,9 @@ public: virtual ~Anything() { } }; +/** + * Implementation of the Anything interface that wraps a value of the given type. + */ template<typename T> class AnyWrapper : public Anything { @@ -22,6 +29,9 @@ private: T _value; }; +/** + * Interface for a key value store of Anything instances. + */ class IObjectStore { public: @@ -30,6 +40,9 @@ public: virtual const Anything * get(const vespalib::string & key) const = 0; }; +/** + * Object store implementation on top of a hash map. + */ class ObjectStore : public IObjectStore { public: @@ -42,4 +55,20 @@ private: ObjectMap _objectMap; }; +namespace objectstore { + +/** + * Utility function that gets the value stored in an Anything instance (via AnyWrapper). + */ +template<typename T> +const T & +as_value(const Anything &val) { + using WrapperType = AnyWrapper<T>; + const auto *wrapper = dynamic_cast<const WrapperType *>(&val); + assert(wrapper != nullptr); + return wrapper->getValue(); +} + +} + } diff --git a/searchlib/src/vespa/searchlib/fef/test/queryenvironment.cpp b/searchlib/src/vespa/searchlib/fef/test/queryenvironment.cpp index ee305dcff55..4697675c071 100644 --- a/searchlib/src/vespa/searchlib/fef/test/queryenvironment.cpp +++ b/searchlib/src/vespa/searchlib/fef/test/queryenvironment.cpp @@ -2,21 +2,17 @@ #include "queryenvironment.h" -namespace search { -namespace fef { -namespace test { +namespace search::fef::test { QueryEnvironment::QueryEnvironment(IndexEnvironment *env) : _indexEnv(env), _terms(), _properties(), _location(), - _attrCtx((env == NULL) ? attribute::IAttributeContext::UP() : env->getAttributeMap().createContext()) + _attrCtx((env == nullptr) ? attribute::IAttributeContext::UP() : env->getAttributeMap().createContext()) { } QueryEnvironment::~QueryEnvironment() { } -} // namespace test -} // namespace fef -} // namespace search +} diff --git a/searchlib/src/vespa/searchlib/fef/test/queryenvironment.h b/searchlib/src/vespa/searchlib/fef/test/queryenvironment.h index 4d7a92586e6..40898281794 100644 --- a/searchlib/src/vespa/searchlib/fef/test/queryenvironment.h +++ b/searchlib/src/vespa/searchlib/fef/test/queryenvironment.h @@ -6,10 +6,9 @@ #include <vespa/searchlib/fef/iqueryenvironment.h> #include <vespa/searchlib/fef/location.h> #include <vespa/searchlib/fef/simpletermdata.h> +#include <unordered_map> -namespace search { -namespace fef { -namespace test { +namespace search::fef::test { /** * Implementation of the IQueryEnvironment interface used for testing. @@ -25,6 +24,7 @@ private: Properties _properties; Location _location; search::attribute::IAttributeContext::UP _attrCtx; + std::unordered_map<std::string, double> _avg_field_lengths; public: /** @@ -40,7 +40,11 @@ public: const ITermData *getTerm(uint32_t idx) const override { return idx < _terms.size() ? &_terms[idx] : NULL; } const Location & getLocation() const override { return _location; } const search::attribute::IAttributeContext &getAttributeContext() const override { return *_attrCtx; } - double get_average_field_length(const vespalib::string &) const override { + double get_average_field_length(const vespalib::string& field_name) const override { + auto itr = _avg_field_lengths.find(field_name); + if (itr != _avg_field_lengths.end()) { + return itr->second; + } return 1.0; } const IIndexEnvironment &getIndexEnvironment() const override { assert(_indexEnv != NULL); return *_indexEnv; } @@ -79,9 +83,9 @@ public: /** Returns a reference to the location of this. */ Location & getLocation() { return _location; } + + std::unordered_map<std::string, double>& get_avg_field_lengths() { return _avg_field_lengths; } }; -} // namespace test -} // namespace fef -} // namespace search +} diff --git a/searchlib/src/vespa/searchlib/fef/test/queryenvironmentbuilder.cpp b/searchlib/src/vespa/searchlib/fef/test/queryenvironmentbuilder.cpp index 2d9fb998869..67a2eaf5677 100644 --- a/searchlib/src/vespa/searchlib/fef/test/queryenvironmentbuilder.cpp +++ b/searchlib/src/vespa/searchlib/fef/test/queryenvironmentbuilder.cpp @@ -2,16 +2,13 @@ #include "queryenvironmentbuilder.h" -namespace search { -namespace fef { -namespace test { +namespace search::fef::test { QueryEnvironmentBuilder::QueryEnvironmentBuilder(QueryEnvironment &env, MatchDataLayout &layout) : _queryEnv(env), _layout(layout) { - // empty } QueryEnvironmentBuilder::~QueryEnvironmentBuilder() { } @@ -39,8 +36,8 @@ QueryEnvironmentBuilder::addIndexNode(const std::vector<vespalib::string> &field td.setWeight(search::query::Weight(100)); for (uint32_t i = 0; i < fieldNames.size(); ++i) { const FieldInfo *info = _queryEnv.getIndexEnv()->getFieldByName(fieldNames[i]); - if (info == NULL || info->type() != FieldType::INDEX) { - return NULL; + if (info == nullptr || info->type() != FieldType::INDEX) { + return nullptr; } SimpleTermFieldData &tfd = td.addField(info->id()); tfd.setHandle(_layout.allocTermField(tfd.getFieldId())); @@ -52,8 +49,8 @@ SimpleTermData * QueryEnvironmentBuilder::addAttributeNode(const vespalib::string &attrName) { const FieldInfo *info = _queryEnv.getIndexEnv()->getFieldByName(attrName); - if (info == NULL || info->type() != FieldType::ATTRIBUTE) { - return NULL; + if (info == nullptr || info->type() != FieldType::ATTRIBUTE) { + return nullptr; } _queryEnv.getTerms().push_back(SimpleTermData()); SimpleTermData &td = _queryEnv.getTerms().back(); @@ -63,6 +60,11 @@ QueryEnvironmentBuilder::addAttributeNode(const vespalib::string &attrName) return &td; } -} // namespace test -} // namespace fef -} // namespace search +QueryEnvironmentBuilder& +QueryEnvironmentBuilder::set_avg_field_length(const vespalib::string& field_name, double avg_field_length) +{ + _queryEnv.get_avg_field_lengths()[field_name] = avg_field_length; + return *this; +} + +} diff --git a/searchlib/src/vespa/searchlib/fef/test/queryenvironmentbuilder.h b/searchlib/src/vespa/searchlib/fef/test/queryenvironmentbuilder.h index 98aed323f9a..36a63b2a9a2 100644 --- a/searchlib/src/vespa/searchlib/fef/test/queryenvironmentbuilder.h +++ b/searchlib/src/vespa/searchlib/fef/test/queryenvironmentbuilder.h @@ -57,6 +57,8 @@ public: /** Returns a const reference to the match data layout of this. */ const MatchDataLayout &getLayout() const { return _layout; } + QueryEnvironmentBuilder& set_avg_field_length(const vespalib::string& field_name, double avg_field_length); + private: QueryEnvironmentBuilder(const QueryEnvironmentBuilder &); // hide QueryEnvironmentBuilder & operator=(const QueryEnvironmentBuilder &); // hide |