summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Geir Storli <geirst@yahooinc.com> 2022-04-22 11:37:45 +0000
committer Geir Storli <geirst@yahooinc.com> 2022-04-22 11:41:35 +0000
commit 60338ad42fa4b5855fb58782d8819ea526f827c2 (patch)
tree 701f558d4c887e04bc655e6f3b9d28db37fb7724
parent a54e8bacfd1d2be446fc89e16586acee94f79131 (diff)
Improve and extend visit trace for nearest neighbor blueprint.
-rw-r--r-- searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp 5
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp 37
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h 4
3 files changed, 20 insertions, 26 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index 1d3305d2c1a..ec75a0d6d06 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -1041,7 +1041,6 @@ public:
100100.25,
global_filter_lower_limit, 1.0);
EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
- EXPECT_EQUAL(approximate, bp->may_approximate());
EXPECT_EQUAL(100100.25 * 100100.25, bp->get_distance_threshold());
return bp;
}
@@ -1068,7 +1067,6 @@ TEST_F("NN blueprint handles empty filter", NearestNeighborBlueprintFixture)
auto empty_filter = GlobalFilter::create();
bp->set_global_filter(*empty_filter);
EXPECT_EQUAL(3u, bp->getState().estimate().estHits);
- EXPECT_TRUE(bp->may_approximate());
EXPECT_EQUAL(NNBA::INDEX_TOP_K, bp->get_algorithm());
}
@@ -1081,7 +1079,6 @@ TEST_F("NN blueprint handles strong filter", NearestNeighborBlueprintFixture)
auto strong_filter = GlobalFilter::create(std::move(filter));
bp->set_global_filter(*strong_filter);
EXPECT_EQUAL(1u, bp->getState().estimate().estHits);
- EXPECT_TRUE(bp->may_approximate());
EXPECT_EQUAL(NNBA::INDEX_TOP_K_WITH_FILTER, bp->get_algorithm());
}
@@ -1099,7 +1096,6 @@ TEST_F("NN blueprint handles weak filter", NearestNeighborBlueprintFixture)
auto weak_filter = GlobalFilter::create(std::move(filter));
bp->set_global_filter(*weak_filter);
EXPECT_EQUAL(3u, bp->getState().estimate().estHits);
- EXPECT_TRUE(bp->may_approximate());
EXPECT_EQUAL(NNBA::INDEX_TOP_K_WITH_FILTER, bp->get_algorithm());
}
@@ -1112,7 +1108,6 @@ TEST_F("NN blueprint handles strong filter triggering brute force search", Neare
auto strong_filter = GlobalFilter::create(std::move(filter));
bp->set_global_filter(*strong_filter);
EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
- EXPECT_FALSE(bp->may_approximate());
EXPECT_EQUAL(NNBA::BRUTE_FORCE_FALLBACK, bp->get_algorithm());
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index bdcbb3db633..01d81f6398e 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -89,6 +89,7 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
_found_hits(),
_algorithm(Algorithm::BRUTE_FORCE),
_global_filter(GlobalFilter::create()),
+ _global_filter_set(false),
_global_filter_hits(),
_global_filter_hit_ratio()
{
@@ -122,6 +123,7 @@ void
NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter)
{
_global_filter = global_filter.shared_from_this();
+ _global_filter_set = true;
auto nns_index = _attr_tensor.nearest_neighbor_index();
LOG(debug, "set_global_filter with: %s / %s / %s",
(_approximate ? "approximate" : "exact"),
@@ -134,7 +136,6 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter)
LOG(debug, "set_global_filter getNumDocs: %u / max_hits %u", est_hits, max_hits);
double max_hit_ratio = static_cast<double>(max_hits) / est_hits;
if (max_hit_ratio < _global_filter_lower_limit) {
- _approximate = false;
_algorithm = Algorithm::BRUTE_FORCE_FALLBACK;
LOG(debug, "too many hits filtered out, using brute force implementation");
} else {
@@ -143,30 +144,27 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter)
_global_filter_hits = max_hits;
_global_filter_hit_ratio = max_hit_ratio;
}
- if (_approximate) {
+ if (_algorithm != Algorithm::BRUTE_FORCE_FALLBACK) {
est_hits = std::min(est_hits, _target_num_hits);
setEstimate(HitEstimate(est_hits, false));
- perform_top_k();
+ perform_top_k(nns_index);
LOG(debug, "perform_top_k found %zu hits", _found_hits.size());
}
}
}
void
-NearestNeighborBlueprint::perform_top_k()
+NearestNeighborBlueprint::perform_top_k(const search::tensor::NearestNeighborIndex* nns_index)
{
- auto nns_index = _attr_tensor.nearest_neighbor_index();
- if (_approximate && nns_index) {
- auto lhs = _query_tensor->cells();
- uint32_t k = _target_num_hits;
- if (_global_filter->has_filter()) {
- auto filter = _global_filter->filter();
- _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits, _distance_threshold);
- _algorithm = Algorithm::INDEX_TOP_K_WITH_FILTER;
- } else {
- _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold);
- _algorithm = Algorithm::INDEX_TOP_K;
- }
+ auto lhs = _query_tensor->cells();
+ uint32_t k = _target_num_hits;
+ if (_global_filter->has_filter()) {
+ auto filter = _global_filter->filter();
+ _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits, _distance_threshold);
+ _algorithm = Algorithm::INDEX_TOP_K_WITH_FILTER;
+ } else {
+ _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold);
+ _algorithm = Algorithm::INDEX_TOP_K;
}
}
@@ -191,14 +189,15 @@ NearestNeighborBlueprint::visitMembers(vespalib::ObjectVisitor& visitor) const
visitor.visitString("query_tensor", _query_tensor->type().to_spec());
visitor.visitInt("target_num_hits", _target_num_hits);
visitor.visitInt("explore_additional_hits", _explore_additional_hits);
- visitor.visitBool("approximate", _approximate);
+ visitor.visitBool("wanted_approximate", _approximate);
visitor.visitBool("has_index", _attr_tensor.nearest_neighbor_index());
visitor.visitString("algorithm", to_string(_algorithm));
visitor.visitInt("top_k_hits", _found_hits.size());
visitor.openStruct("global_filter", "GlobalFilter");
- visitor.visitBool("is_set", (_global_filter != nullptr));
- visitor.visitBool("has_filter", (_global_filter && _global_filter->has_filter()));
+ visitor.visitBool("wanted", getState().want_global_filter());
+ visitor.visitBool("set", _global_filter_set);
+ visitor.visitBool("calculated", _global_filter->has_filter());
visitor.visitFloat("lower_limit", _global_filter_lower_limit);
visitor.visitFloat("upper_limit", _global_filter_upper_limit);
if (_global_filter_hits.has_value()) {
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
index 7922036dc42..7637c4dd6b7 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
@@ -41,10 +41,11 @@ private:
std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits;
Algorithm _algorithm;
std::shared_ptr<const GlobalFilter> _global_filter;
+ bool _global_filter_set;
std::optional<uint32_t> _global_filter_hits;
std::optional<double> _global_filter_hit_ratio;
- void perform_top_k();
+ void perform_top_k(const search::tensor::NearestNeighborIndex* nns_index);
public:
NearestNeighborBlueprint(const queryeval::FieldSpec& field,
const tensor::ITensorAttribute& attr_tensor,
@@ -60,7 +61,6 @@ public:
const vespalib::eval::Value& get_query_tensor() const { return *_query_tensor; }
uint32_t get_target_num_hits() const { return _target_num_hits; }
void set_global_filter(const GlobalFilter &global_filter) override;
- bool may_approximate() const { return _approximate; }
Algorithm get_algorithm() const { return _algorithm; }
double get_distance_threshold() const { return _distance_threshold; }