diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-05-11 10:48:47 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-05-12 14:29:42 +0000 |
commit | aa958f364f43e58a0e8fff81b4e2e77513f22a7b (patch) | |
tree | e71edc0a26bc1cadcbcf3d83f736988051eefea5 /searchlib | |
parent | b43e8090a860bf92b07153a2fd01d95f7fa2e548 (diff) |
use global filter in NN blueprint
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp | 34 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h | 1 |
2 files changed, 26 insertions, 9 deletions
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index b0db678dfc6..3ea515e5cd4 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -59,7 +59,8 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f _explore_additional_hits(explore_additional_hits), _fallback_dist_fun(), _distance_heap(target_num_hits), - _found_hits() + _found_hits(), + _global_filter() { auto lct = _query_tensor->cellsRef().type; auto rct = _attr_tensor.getTensorType().cell_type(); @@ -73,9 +74,6 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f _dist_fun = nns_index->distance_function(); } uint32_t est_hits = _attr_tensor.getNumDocs(); - if (_approximate && nns_index) { - est_hits = std::min(target_num_hits, est_hits); - } setEstimate(HitEstimate(est_hits, false)); set_want_global_filter(true); } @@ -85,8 +83,22 @@ NearestNeighborBlueprint::~NearestNeighborBlueprint() = default; void NearestNeighborBlueprint::set_global_filter(std::shared_ptr<BitVector> global_filter) { - // XXX do something with global_filter here - (void) global_filter; + _global_filter = global_filter; + auto nns_index = _attr_tensor.nearest_neighbor_index(); + if (_approximate && nns_index) { + uint32_t est_hits = _attr_tensor.getNumDocs(); + if (_global_filter) { + uint32_t max_hits = _global_filter->countTrueBits(); + if (max_hits * 10 < est_hits) { + // too many hits filtered out, use brute force implementation: + _approximate = false; + return; + } + est_hits = std::min(est_hits, max_hits); + } + est_hits = std::min(est_hits, _target_num_hits); + setEstimate(HitEstimate(est_hits, false)); + } } void @@ -96,11 +108,15 @@ NearestNeighborBlueprint::perform_top_k() if (_approximate && nns_index) { auto lhs_type = _query_tensor->fast_type(); auto 
rhs_type = _attr_tensor.getTensorType(); - // XXX deal with different cell types later + // different cell types should have been converted already if (lhs_type == rhs_type) { auto lhs = _query_tensor->cellsRef(); uint32_t k = _target_num_hits; - _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits); + if (_global_filter) { + _found_hits = nns_index->find_top_k_with_filter(k, lhs, *_global_filter, k + _explore_additional_hits); + } else { + _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits); + } } } } @@ -122,7 +138,7 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData } const vespalib::tensor::DenseTensorView &qT = *_query_tensor; return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor, - _distance_heap, nullptr, _dist_fun); + _distance_heap, _global_filter.get(), _dist_fun); } void diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h index cdb2b23e318..c1c6c28de37 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h @@ -28,6 +28,7 @@ private: const search::tensor::DistanceFunction *_dist_fun; mutable NearestNeighborDistanceHeap _distance_heap; std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits; + std::shared_ptr<search::BitVector> _global_filter; void perform_top_k(); public: |