summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@broadpark.no>2020-06-23 10:09:55 +0200
committerTor Egge <Tor.Egge@broadpark.no>2020-06-23 10:09:55 +0200
commit818001352928e28fbd27a86b6aadd640df62850e (patch)
tree039d323fcc8320cb3635ef8a7e1cf961a6cd73dd
parent749b7f7637c8b5c80dfe813d04c5301054b311c4 (diff)
Check brute force limit in nearest neighbor blueprint.
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp16
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp20
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h3
3 files changed, 29 insertions, 10 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index 6608959662f..c698a1d612b 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -887,13 +887,13 @@ public:
return std::unique_ptr<QueryTensor>(tensor);
}
- std::unique_ptr<NearestNeighborBlueprint> make_blueprint() {
+ std::unique_ptr<NearestNeighborBlueprint> make_blueprint(double brute_force_limit = 0.05) {
search::queryeval::FieldSpec field("foo", 0, 0);
auto bp = std::make_unique<NearestNeighborBlueprint>(
field,
as_dense_tensor(),
createDenseTensor(vec_2d(17, 42)),
- 3, true, 5, 0.05);
+ 3, true, 5, brute_force_limit);
EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
EXPECT_TRUE(bp->may_approximate());
return bp;
@@ -938,4 +938,16 @@ TEST_F("NN blueprint handles weak filter", NearestNeighborBlueprintFixture)
EXPECT_TRUE(bp->may_approximate());
}
+TEST_F("NN blueprint handles strong filter triggering brute force search", NearestNeighborBlueprintFixture)
+{
+ auto bp = f.make_blueprint(0.2);
+ auto filter = search::BitVector::create(11);
+ filter->setBit(3);
+ filter->invalidateCachedCount();
+ auto strong_filter = GlobalFilter::create(std::move(filter));
+ bp->set_global_filter(*strong_filter);
+ EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
+ EXPECT_FALSE(bp->may_approximate());
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index fcf8b78056d..3da20603e06 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -59,6 +59,7 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
_query_tensor(std::move(query_tensor)),
_target_num_hits(target_num_hits),
_approximate(approximate),
+ _use_brute_force(false),
_explore_additional_hits(explore_additional_hits),
_brute_force_limit(brute_force_limit),
_fallback_dist_fun(),
@@ -98,15 +99,20 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter)
if (_global_filter->has_filter()) {
uint32_t max_hits = _global_filter->filter()->countTrueBits();
LOG(debug, "set_global_filter getNumDocs: %u / max_hits %u", est_hits, max_hits);
- if (max_hits * 10 < est_hits) {
- LOG(debug, "too many hits filtered out, consider using brute force implementation");
+ double max_hit_ratio = static_cast<double>(max_hits) / est_hits;
+ if (max_hit_ratio < _brute_force_limit) {
+ _use_brute_force = true;
+ LOG(debug, "too many hits filtered out, using brute force implementation");
+ } else {
+ est_hits = std::min(est_hits, max_hits);
}
- est_hits = std::min(est_hits, max_hits);
}
- est_hits = std::min(est_hits, _target_num_hits);
- setEstimate(HitEstimate(est_hits, false));
- perform_top_k();
- LOG(debug, "perform_top_k found %zu hits", _found_hits.size());
+ if (!_use_brute_force) {
+ est_hits = std::min(est_hits, _target_num_hits);
+ setEstimate(HitEstimate(est_hits, false));
+ perform_top_k();
+ LOG(debug, "perform_top_k found %zu hits", _found_hits.size());
+ }
}
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
index 8656e5b4bf2..8050d350af5 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
@@ -23,6 +23,7 @@ private:
std::unique_ptr<vespalib::tensor::DenseTensorView> _query_tensor;
uint32_t _target_num_hits;
bool _approximate;
+ bool _use_brute_force;
uint32_t _explore_additional_hits;
double _brute_force_limit;
search::tensor::DistanceFunction::UP _fallback_dist_fun;
@@ -44,7 +45,7 @@ public:
const vespalib::tensor::DenseTensorView& get_query_tensor() const { return *_query_tensor; }
uint32_t get_target_num_hits() const { return _target_num_hits; }
void set_global_filter(const GlobalFilter &global_filter) override;
- bool may_approximate() const { return _approximate; }
+ bool may_approximate() const { return _approximate && !_use_brute_force; }
std::unique_ptr<SearchIterator> createLeafSearch(const search::fef::TermFieldMatchDataArray& tfmda,
bool strict) const override;