diff options
author | Geir Storli <geirst@yahooinc.com> | 2022-04-22 11:37:45 +0000 |
---|---|---|
committer | Geir Storli <geirst@yahooinc.com> | 2022-04-22 11:41:35 +0000 |
commit | 60338ad42fa4b5855fb58782d8819ea526f827c2 (patch) | |
tree | 701f558d4c887e04bc655e6f3b9d28db37fb7724 | |
parent | a54e8bacfd1d2be446fc89e16586acee94f79131 (diff) |
Improve and extend visit trace for nearest neighbor blueprint.
3 files changed, 20 insertions, 26 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index 1d3305d2c1a..ec75a0d6d06 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -1041,7 +1041,6 @@ public: 100100.25, global_filter_lower_limit, 1.0); EXPECT_EQUAL(11u, bp->getState().estimate().estHits); - EXPECT_EQUAL(approximate, bp->may_approximate()); EXPECT_EQUAL(100100.25 * 100100.25, bp->get_distance_threshold()); return bp; } @@ -1068,7 +1067,6 @@ TEST_F("NN blueprint handles empty filter", NearestNeighborBlueprintFixture) auto empty_filter = GlobalFilter::create(); bp->set_global_filter(*empty_filter); EXPECT_EQUAL(3u, bp->getState().estimate().estHits); - EXPECT_TRUE(bp->may_approximate()); EXPECT_EQUAL(NNBA::INDEX_TOP_K, bp->get_algorithm()); } @@ -1081,7 +1079,6 @@ TEST_F("NN blueprint handles strong filter", NearestNeighborBlueprintFixture) auto strong_filter = GlobalFilter::create(std::move(filter)); bp->set_global_filter(*strong_filter); EXPECT_EQUAL(1u, bp->getState().estimate().estHits); - EXPECT_TRUE(bp->may_approximate()); EXPECT_EQUAL(NNBA::INDEX_TOP_K_WITH_FILTER, bp->get_algorithm()); } @@ -1099,7 +1096,6 @@ TEST_F("NN blueprint handles weak filter", NearestNeighborBlueprintFixture) auto weak_filter = GlobalFilter::create(std::move(filter)); bp->set_global_filter(*weak_filter); EXPECT_EQUAL(3u, bp->getState().estimate().estHits); - EXPECT_TRUE(bp->may_approximate()); EXPECT_EQUAL(NNBA::INDEX_TOP_K_WITH_FILTER, bp->get_algorithm()); } @@ -1112,7 +1108,6 @@ TEST_F("NN blueprint handles strong filter triggering brute force search", Neare auto strong_filter = GlobalFilter::create(std::move(filter)); bp->set_global_filter(*strong_filter); EXPECT_EQUAL(11u, bp->getState().estimate().estHits); - EXPECT_FALSE(bp->may_approximate()); EXPECT_EQUAL(NNBA::BRUTE_FORCE_FALLBACK, bp->get_algorithm()); } diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index bdcbb3db633..01d81f6398e 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -89,6 +89,7 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f _found_hits(), _algorithm(Algorithm::BRUTE_FORCE), _global_filter(GlobalFilter::create()), + _global_filter_set(false), _global_filter_hits(), _global_filter_hit_ratio() { @@ -122,6 +123,7 @@ void NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter) { _global_filter = global_filter.shared_from_this(); + _global_filter_set = true; auto nns_index = _attr_tensor.nearest_neighbor_index(); LOG(debug, "set_global_filter with: %s / %s / %s", (_approximate ? "approximate" : "exact"), @@ -134,7 +136,6 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter) LOG(debug, "set_global_filter getNumDocs: %u / max_hits %u", est_hits, max_hits); double max_hit_ratio = static_cast<double>(max_hits) / est_hits; if (max_hit_ratio < _global_filter_lower_limit) { - _approximate = false; _algorithm = Algorithm::BRUTE_FORCE_FALLBACK; LOG(debug, "too many hits filtered out, using brute force implementation"); } else { @@ -143,30 +144,27 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter) _global_filter_hits = max_hits; _global_filter_hit_ratio = max_hit_ratio; } - if (_approximate) { + if (_algorithm != Algorithm::BRUTE_FORCE_FALLBACK) { est_hits = std::min(est_hits, _target_num_hits); setEstimate(HitEstimate(est_hits, false)); - perform_top_k(); + perform_top_k(nns_index); LOG(debug, "perform_top_k found %zu hits", _found_hits.size()); } } } void -NearestNeighborBlueprint::perform_top_k() +NearestNeighborBlueprint::perform_top_k(const search::tensor::NearestNeighborIndex* nns_index) { - auto nns_index = _attr_tensor.nearest_neighbor_index(); - if (_approximate && nns_index) { - auto lhs = _query_tensor->cells(); - uint32_t k = _target_num_hits; - if (_global_filter->has_filter()) { - auto filter = _global_filter->filter(); - _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits, _distance_threshold); - _algorithm = Algorithm::INDEX_TOP_K_WITH_FILTER; - } else { - _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold); - _algorithm = Algorithm::INDEX_TOP_K; - } + auto lhs = _query_tensor->cells(); + uint32_t k = _target_num_hits; + if (_global_filter->has_filter()) { + auto filter = _global_filter->filter(); + _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits, _distance_threshold); + _algorithm = Algorithm::INDEX_TOP_K_WITH_FILTER; + } else { + _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold); + _algorithm = Algorithm::INDEX_TOP_K; } } @@ -191,14 +189,15 @@ NearestNeighborBlueprint::visitMembers(vespalib::ObjectVisitor& visitor) const visitor.visitString("query_tensor", _query_tensor->type().to_spec()); visitor.visitInt("target_num_hits", _target_num_hits); visitor.visitInt("explore_additional_hits", _explore_additional_hits); - visitor.visitBool("approximate", _approximate); + visitor.visitBool("wanted_approximate", _approximate); visitor.visitBool("has_index", _attr_tensor.nearest_neighbor_index()); visitor.visitString("algorithm", to_string(_algorithm)); visitor.visitInt("top_k_hits", _found_hits.size()); visitor.openStruct("global_filter", "GlobalFilter"); - visitor.visitBool("is_set", (_global_filter != nullptr)); - visitor.visitBool("has_filter", (_global_filter && _global_filter->has_filter())); + visitor.visitBool("wanted", getState().want_global_filter()); + visitor.visitBool("set", _global_filter_set); + visitor.visitBool("calculated", _global_filter->has_filter()); visitor.visitFloat("lower_limit", _global_filter_lower_limit); visitor.visitFloat("upper_limit", _global_filter_upper_limit); if (_global_filter_hits.has_value()) { diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h index 7922036dc42..7637c4dd6b7 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h @@ -41,10 +41,11 @@ private: std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits; Algorithm _algorithm; std::shared_ptr<const GlobalFilter> _global_filter; + bool _global_filter_set; std::optional<uint32_t> _global_filter_hits; std::optional<double> _global_filter_hit_ratio; - void perform_top_k(); + void perform_top_k(const search::tensor::NearestNeighborIndex* nns_index); public: NearestNeighborBlueprint(const queryeval::FieldSpec& field, const tensor::ITensorAttribute& attr_tensor, @@ -60,7 +61,6 @@ public: const vespalib::eval::Value& get_query_tensor() const { return *_query_tensor; } uint32_t get_target_num_hits() const { return _target_num_hits; } void set_global_filter(const GlobalFilter &global_filter) override; - bool may_approximate() const { return _approximate; } Algorithm get_algorithm() const { return _algorithm; } double get_distance_threshold() const { return _distance_threshold; } |