diff options
author | Tor Egge <Tor.Egge@broadpark.no> | 2020-06-23 10:09:55 +0200 |
---|---|---|
committer | Tor Egge <Tor.Egge@broadpark.no> | 2020-06-23 10:09:55 +0200 |
commit | 818001352928e28fbd27a86b6aadd640df62850e (patch) | |
tree | 039d323fcc8320cb3635ef8a7e1cf961a6cd73dd | |
parent | 749b7f7637c8b5c80dfe813d04c5301054b311c4 (diff) |
Check brute force limit in nearest neighbor blueprint.
3 files changed, 29 insertions, 10 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index 6608959662f..c698a1d612b 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -887,13 +887,13 @@ public: return std::unique_ptr<QueryTensor>(tensor); } - std::unique_ptr<NearestNeighborBlueprint> make_blueprint() { + std::unique_ptr<NearestNeighborBlueprint> make_blueprint(double brute_force_limit = 0.05) { search::queryeval::FieldSpec field("foo", 0, 0); auto bp = std::make_unique<NearestNeighborBlueprint>( field, as_dense_tensor(), createDenseTensor(vec_2d(17, 42)), - 3, true, 5, 0.05); + 3, true, 5, brute_force_limit); EXPECT_EQUAL(11u, bp->getState().estimate().estHits); EXPECT_TRUE(bp->may_approximate()); return bp; @@ -938,4 +938,16 @@ TEST_F("NN blueprint handles weak filter", NearestNeighborBlueprintFixture) EXPECT_TRUE(bp->may_approximate()); } +TEST_F("NN blueprint handles strong filter triggering brute force search", NearestNeighborBlueprintFixture) +{ + auto bp = f.make_blueprint(0.2); + auto filter = search::BitVector::create(11); + filter->setBit(3); + filter->invalidateCachedCount(); + auto strong_filter = GlobalFilter::create(std::move(filter)); + bp->set_global_filter(*strong_filter); + EXPECT_EQUAL(11u, bp->getState().estimate().estHits); + EXPECT_FALSE(bp->may_approximate()); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index fcf8b78056d..3da20603e06 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -59,6 +59,7 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f _query_tensor(std::move(query_tensor)), _target_num_hits(target_num_hits), _approximate(approximate), + _use_brute_force(false), _explore_additional_hits(explore_additional_hits), _brute_force_limit(brute_force_limit), _fallback_dist_fun(), @@ -98,15 +99,20 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter) if (_global_filter->has_filter()) { uint32_t max_hits = _global_filter->filter()->countTrueBits(); LOG(debug, "set_global_filter getNumDocs: %u / max_hits %u", est_hits, max_hits); - if (max_hits * 10 < est_hits) { - LOG(debug, "too many hits filtered out, consider using brute force implementation"); + double max_hit_ratio = static_cast<double>(max_hits) / est_hits; + if (max_hit_ratio < _brute_force_limit) { + _use_brute_force = true; + LOG(debug, "too many hits filtered out, using brute force implementation"); + } else { + est_hits = std::min(est_hits, max_hits); } - est_hits = std::min(est_hits, max_hits); } - est_hits = std::min(est_hits, _target_num_hits); - setEstimate(HitEstimate(est_hits, false)); - perform_top_k(); - LOG(debug, "perform_top_k found %zu hits", _found_hits.size()); + if (!_use_brute_force) { + est_hits = std::min(est_hits, _target_num_hits); + setEstimate(HitEstimate(est_hits, false)); + perform_top_k(); + LOG(debug, "perform_top_k found %zu hits", _found_hits.size()); + } } } diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h index 8656e5b4bf2..8050d350af5 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h @@ -23,6 +23,7 @@ private: std::unique_ptr<vespalib::tensor::DenseTensorView> _query_tensor; uint32_t _target_num_hits; bool _approximate; + bool _use_brute_force; uint32_t _explore_additional_hits; double _brute_force_limit; search::tensor::DistanceFunction::UP _fallback_dist_fun; @@ -44,7 +45,7 @@ public: const vespalib::tensor::DenseTensorView& get_query_tensor() const { return *_query_tensor; } uint32_t get_target_num_hits() const { return _target_num_hits; } void set_global_filter(const GlobalFilter &global_filter) override; - bool may_approximate() const { return _approximate; } + bool may_approximate() const { return _approximate && !_use_brute_force; } std::unique_ptr<SearchIterator> createLeafSearch(const search::fef::TermFieldMatchDataArray& tfmda, bool strict) const override; |