aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp')
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp44
1 files changed, 18 insertions, 26 deletions
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index bdcbb3db633..73eaa773c53 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -89,6 +89,7 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
_found_hits(),
_algorithm(Algorithm::BRUTE_FORCE),
_global_filter(GlobalFilter::create()),
+ _global_filter_set(false),
_global_filter_hits(),
_global_filter_hit_ratio()
{
@@ -122,51 +123,41 @@ void
NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter)
{
_global_filter = global_filter.shared_from_this();
+ _global_filter_set = true;
auto nns_index = _attr_tensor.nearest_neighbor_index();
- LOG(debug, "set_global_filter with: %s / %s / %s",
- (_approximate ? "approximate" : "exact"),
- (nns_index ? "nns_index" : "no_index"),
- (_global_filter->has_filter() ? "has_filter" : "no_filter"));
if (_approximate && nns_index) {
uint32_t est_hits = _attr_tensor.get_num_docs();
if (_global_filter->has_filter()) {
uint32_t max_hits = _global_filter->filter()->countTrueBits();
- LOG(debug, "set_global_filter getNumDocs: %u / max_hits %u", est_hits, max_hits);
double max_hit_ratio = static_cast<double>(max_hits) / est_hits;
if (max_hit_ratio < _global_filter_lower_limit) {
- _approximate = false;
_algorithm = Algorithm::BRUTE_FORCE_FALLBACK;
- LOG(debug, "too many hits filtered out, using brute force implementation");
} else {
est_hits = std::min(est_hits, max_hits);
}
_global_filter_hits = max_hits;
_global_filter_hit_ratio = max_hit_ratio;
}
- if (_approximate) {
+ if (_algorithm != Algorithm::BRUTE_FORCE_FALLBACK) {
est_hits = std::min(est_hits, _target_num_hits);
setEstimate(HitEstimate(est_hits, false));
- perform_top_k();
- LOG(debug, "perform_top_k found %zu hits", _found_hits.size());
+ perform_top_k(nns_index);
}
}
}
void
-NearestNeighborBlueprint::perform_top_k()
+NearestNeighborBlueprint::perform_top_k(const search::tensor::NearestNeighborIndex* nns_index)
{
- auto nns_index = _attr_tensor.nearest_neighbor_index();
- if (_approximate && nns_index) {
- auto lhs = _query_tensor->cells();
- uint32_t k = _target_num_hits;
- if (_global_filter->has_filter()) {
- auto filter = _global_filter->filter();
- _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits, _distance_threshold);
- _algorithm = Algorithm::INDEX_TOP_K_WITH_FILTER;
- } else {
- _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold);
- _algorithm = Algorithm::INDEX_TOP_K;
- }
+ auto lhs = _query_tensor->cells();
+ uint32_t k = _target_num_hits;
+ if (_global_filter->has_filter()) {
+ auto filter = _global_filter->filter();
+ _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits, _distance_threshold);
+ _algorithm = Algorithm::INDEX_TOP_K_WITH_FILTER;
+ } else {
+ _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold);
+ _algorithm = Algorithm::INDEX_TOP_K;
}
}
@@ -191,14 +182,15 @@ NearestNeighborBlueprint::visitMembers(vespalib::ObjectVisitor& visitor) const
visitor.visitString("query_tensor", _query_tensor->type().to_spec());
visitor.visitInt("target_num_hits", _target_num_hits);
visitor.visitInt("explore_additional_hits", _explore_additional_hits);
- visitor.visitBool("approximate", _approximate);
+ visitor.visitBool("wanted_approximate", _approximate);
visitor.visitBool("has_index", _attr_tensor.nearest_neighbor_index());
visitor.visitString("algorithm", to_string(_algorithm));
visitor.visitInt("top_k_hits", _found_hits.size());
visitor.openStruct("global_filter", "GlobalFilter");
- visitor.visitBool("is_set", (_global_filter != nullptr));
- visitor.visitBool("has_filter", (_global_filter && _global_filter->has_filter()));
+ visitor.visitBool("wanted", getState().want_global_filter());
+ visitor.visitBool("set", _global_filter_set);
+ visitor.visitBool("calculated", _global_filter->has_filter());
visitor.visitFloat("lower_limit", _global_filter_lower_limit);
visitor.visitFloat("upper_limit", _global_filter_upper_limit);
if (_global_filter_hits.has_value()) {