summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-05-11 10:48:47 +0000
committerArne Juul <arnej@verizonmedia.com>2020-05-12 14:29:42 +0000
commitaa958f364f43e58a0e8fff81b4e2e77513f22a7b (patch)
treee71edc0a26bc1cadcbcf3d83f736988051eefea5 /searchlib
parentb43e8090a860bf92b07153a2fd01d95f7fa2e548 (diff)
use global filter in NN blueprint
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp34
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h1
2 files changed, 26 insertions, 9 deletions
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index b0db678dfc6..3ea515e5cd4 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -59,7 +59,8 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
_explore_additional_hits(explore_additional_hits),
_fallback_dist_fun(),
_distance_heap(target_num_hits),
- _found_hits()
+ _found_hits(),
+ _global_filter()
{
auto lct = _query_tensor->cellsRef().type;
auto rct = _attr_tensor.getTensorType().cell_type();
@@ -73,9 +74,6 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
_dist_fun = nns_index->distance_function();
}
uint32_t est_hits = _attr_tensor.getNumDocs();
- if (_approximate && nns_index) {
- est_hits = std::min(target_num_hits, est_hits);
- }
setEstimate(HitEstimate(est_hits, false));
set_want_global_filter(true);
}
@@ -85,8 +83,22 @@ NearestNeighborBlueprint::~NearestNeighborBlueprint() = default;
void
NearestNeighborBlueprint::set_global_filter(std::shared_ptr<BitVector> global_filter)
{
- // XXX do something with global_filter here
- (void) global_filter;
+ _global_filter = global_filter;
+ auto nns_index = _attr_tensor.nearest_neighbor_index();
+ if (_approximate && nns_index) {
+ uint32_t est_hits = _attr_tensor.getNumDocs();
+ if (_global_filter) {
+ uint32_t max_hits = _global_filter->countTrueBits();
+ if (max_hits * 10 < est_hits) {
+ // too many hits filtered out, use brute force implementation:
+ _approximate = false;
+ return;
+ }
+ est_hits = std::min(est_hits, max_hits);
+ }
+ est_hits = std::min(est_hits, _target_num_hits);
+ setEstimate(HitEstimate(est_hits, false));
+ }
}
void
@@ -96,11 +108,15 @@ NearestNeighborBlueprint::perform_top_k()
if (_approximate && nns_index) {
auto lhs_type = _query_tensor->fast_type();
auto rhs_type = _attr_tensor.getTensorType();
- // XXX deal with different cell types later
+ // different cell types should have be converted already
if (lhs_type == rhs_type) {
auto lhs = _query_tensor->cellsRef();
uint32_t k = _target_num_hits;
- _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits);
+ if (_global_filter) {
+ _found_hits = nns_index->find_top_k_with_filter(k, lhs, *_global_filter, k + _explore_additional_hits);
+ } else {
+ _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits);
+ }
}
}
}
@@ -122,7 +138,7 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData
}
const vespalib::tensor::DenseTensorView &qT = *_query_tensor;
return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor,
- _distance_heap, nullptr, _dist_fun);
+ _distance_heap, _global_filter.get(), _dist_fun);
}
void
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
index cdb2b23e318..c1c6c28de37 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
@@ -28,6 +28,7 @@ private:
const search::tensor::DistanceFunction *_dist_fun;
mutable NearestNeighborDistanceHeap _distance_heap;
std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits;
+ std::shared_ptr<search::BitVector> _global_filter;
void perform_top_k();
public: