diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-01-07 12:42:17 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-01-08 10:56:14 +0000 |
commit | cf199f338efafad8c0af7de48094bd3d0037b96a (patch) | |
tree | 2c63ecee902b297006edb41a67925e721bfddff6 /searchlib/src | |
parent | 8aa9ffda4324ddd5baff87be858063c6399a26ca (diff) |
add distanceThreshold option for nearestNeighbor operator
Diffstat (limited to 'searchlib/src')
22 files changed, 153 insertions, 39 deletions
diff --git a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp index 51b4f1d760d..855510d0457 100644 --- a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp +++ b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp @@ -341,7 +341,7 @@ public: request_ctx.set_query_tensor("query_tensor", tensor_spec); } Blueprint::UP create_blueprint() { - query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7, true, 33); + query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7, true, 33, 100100.25); return BlueprintFactoryFixture::create_blueprint(term); } }; diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index 7b597af417d..cbdb2c9bd22 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -13,6 +13,7 @@ #include <vespa/searchlib/tensor/dense_tensor_attribute.h> #include <vespa/searchlib/tensor/direct_tensor_attribute.h> #include <vespa/searchlib/tensor/doc_vector_access.h> +#include <vespa/searchlib/tensor/distance_functions.h> #include <vespa/searchlib/tensor/hnsw_index.h> #include <vespa/searchlib/tensor/nearest_neighbor_index.h> #include <vespa/searchlib/tensor/nearest_neighbor_index_factory.h> @@ -206,24 +207,32 @@ public: _index_value = (reinterpret_cast<const int*>(buf.buffer()))[0]; return true; } - std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, uint32_t explore_k) const override { + std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, uint32_t explore_k, + double distance_threshold) const override + { (void) k; (void) vector; (void) explore_k; + (void) distance_threshold; return std::vector<Neighbor>(); } std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector, - const search::BitVector& filter, uint32_t explore_k) const override + const search::BitVector& filter, uint32_t explore_k, + double distance_threshold) const override { (void) k; (void) vector; (void) explore_k; (void) filter; + (void) distance_threshold; return std::vector<Neighbor>(); } - const search::tensor::DistanceFunction *distance_function() const override { return nullptr; } + const search::tensor::DistanceFunction *distance_function() const override { + static search::tensor::SquaredEuclideanDistance<double> my_dist_fun; + return &my_dist_fun; + } }; class MockNearestNeighborIndexFactory : public NearestNeighborIndexFactory { @@ -914,9 +923,12 @@ public: field, as_dense_tensor(), createDenseTensor(vec_2d(17, 42)), - 3, true, 5, brute_force_limit); + 3, true, 5, + 100100.25, + brute_force_limit); EXPECT_EQUAL(11u, bp->getState().estimate().estHits); EXPECT_TRUE(bp->may_approximate()); + EXPECT_EQUAL(100100.25 * 100100.25, bp->get_distance_threshold()); return bp; } }; diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp index 8441dc2227f..946ad17352d 100644 --- a/searchlib/src/tests/query/query_visitor_test.cpp +++ b/searchlib/src/tests/query/query_visitor_test.cpp @@ -99,7 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() { checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0))); checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0))); checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0))); - checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123, true, 321)); + checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123, true, 321, 100100.25)); } } // namespace diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp index 5a5a5eafb2c..30b4d2ae264 100644 --- a/searchlib/src/tests/query/querybuilder_test.cpp +++ b/searchlib/src/tests/query/querybuilder_test.cpp @@ -110,7 +110,7 @@ Node::UP createQueryTree() { builder.addStringTerm(str[5], view[5], id[5], weight[6]); builder.addStringTerm(str[6], view[6], id[6], weight[7]); } - builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7, true, 33); + builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7, true, 33, 100100.25); } Node::UP node = builder.build(); ASSERT_TRUE(node.get()); @@ -395,8 +395,9 @@ struct MyRegExpTerm : RegExpTerm { struct MyNearestNeighborTerm : NearestNeighborTerm { MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, int32_t i, Weight w, uint32_t target_num_hits, - bool allow_approximate, uint32_t explore_additional_hits) - : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits, allow_approximate, explore_additional_hits) + bool allow_approximate, uint32_t explore_additional_hits, + double distance_threshold) + : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits, allow_approximate, explore_additional_hits, distance_threshold) {} }; diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index ad450a91f33..09790e7e360 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -121,11 +121,12 @@ struct Fixture }; template <bool strict> -SimpleResult find_matches(Fixture &env, const Value &qtv) { +SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std::numeric_limits<double>::max()) { auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._tensorAttr); NearestNeighborDistanceHeap dh(2); + dh.set_distance_threshold(env.dist_fun()->convert_threshold(threshold)); const BitVector *filter = env._global_filter.get(); auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, filter, env.dist_fun()); if (strict) { @@ -159,6 +160,19 @@ verify_iterator_returns_expected_results(const vespalib::string& attribute_tenso EXPECT_EQUAL(result, farExpect); result = find_matches<false>(fixture, *farTensor); EXPECT_EQUAL(result, farExpect); + + SimpleResult null_thr5_exp({1,4,6}); + result = find_matches<true>(fixture, *nullTensor, 5.0); + EXPECT_EQUAL(result, null_thr5_exp); + result = find_matches<false>(fixture, *nullTensor, 5.0); + EXPECT_EQUAL(result, null_thr5_exp); + + SimpleResult far_thr4_exp({2,5}); + result = find_matches<true>(fixture, *farTensor, 4.0); + EXPECT_EQUAL(result, far_thr4_exp); + result = find_matches<false>(fixture, *farTensor, 4.0); + EXPECT_EQUAL(result, far_thr4_exp); + } TEST("require that NearestNeighborIterator returns expected results") { diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp index 06fb95089fd..ee0a2aff80e 100644 --- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp +++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp @@ -24,10 +24,12 @@ void verify_geo_miles(const DistanceFunction *dist_fun, TypedCells t2(p2); double abstract_distance = dist_fun->calc(t1, t2); double raw_score = dist_fun->to_rawscore(abstract_distance); - double m = ((1.0/raw_score)-1.0); - double d_miles = m / 1.609344; + double km = ((1.0/raw_score)-1.0); + double d_miles = km / 1.609344; EXPECT_GE(d_miles, exp_miles*0.99); EXPECT_LE(d_miles, exp_miles*1.01); + double threshold = dist_fun->convert_threshold(km); + EXPECT_DOUBLE_EQ(threshold, abstract_distance); } @@ -50,6 +52,10 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score) double d12 = euclid->calc(t(p1), t(p2)); EXPECT_EQ(d12, 2.0); EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0))); + double threshold = euclid->convert_threshold(8.0); + EXPECT_EQ(threshold, 64.0); + threshold = euclid->convert_threshold(0.5); + EXPECT_EQ(threshold, 0.25); } TEST(DistanceFunctionsTest, angular_gives_expected_score) @@ -75,19 +81,28 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score) EXPECT_DOUBLE_EQ(a23, 1.0); EXPECT_FLOAT_EQ(angular->to_rawscore(a12), 1.0/(1.0 + pi/2)); + double threshold = angular->convert_threshold(pi/2); + EXPECT_DOUBLE_EQ(threshold, 1.0); + double a14 = angular->calc(t(p1), t(p4)); double a24 = angular->calc(t(p2), t(p4)); EXPECT_FLOAT_EQ(a14, 0.5); EXPECT_FLOAT_EQ(a24, 0.5); EXPECT_FLOAT_EQ(angular->to_rawscore(a14), 1.0/(1.0 + pi/3)); + threshold = angular->convert_threshold(pi/3); + EXPECT_DOUBLE_EQ(threshold, 0.5); double a34 = angular->calc(t(p3), t(p4)); EXPECT_FLOAT_EQ(a34, (1.0 - 0.707107)); EXPECT_FLOAT_EQ(angular->to_rawscore(a34), 1.0/(1.0 + pi/4)); + threshold = angular->convert_threshold(pi/4); + EXPECT_FLOAT_EQ(threshold, a34); double a25 = angular->calc(t(p2), t(p5)); EXPECT_DOUBLE_EQ(a25, 2.0); EXPECT_FLOAT_EQ(angular->to_rawscore(a25), 1.0/(1.0 + pi)); + threshold = angular->convert_threshold(pi); + EXPECT_FLOAT_EQ(threshold, 2.0); double a44 = angular->calc(t(p4), t(p4)); EXPECT_GE(a44, 0.0); @@ -98,6 +113,8 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score) EXPECT_GE(a66, 0.0); EXPECT_LT(a66, 0.000001); EXPECT_FLOAT_EQ(angular->to_rawscore(a66), 1.0); + threshold = angular->convert_threshold(0.0); + EXPECT_FLOAT_EQ(threshold, 0.0); double a16 = angular->calc(t(p1), t(p6)); double a26 = angular->calc(t(p2), t(p6)); @@ -127,6 +144,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) EXPECT_DOUBLE_EQ(i12, 1.0); EXPECT_DOUBLE_EQ(i13, 1.0); EXPECT_DOUBLE_EQ(i23, 1.0); + double i14 = innerproduct->calc(t(p1), t(p4)); double i24 = innerproduct->calc(t(p2), t(p4)); EXPECT_DOUBLE_EQ(i14, 0.5); @@ -140,6 +158,13 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) double i44 = innerproduct->calc(t(p4), t(p4)); EXPECT_GE(i44, 0.0); EXPECT_LT(i44, 0.000001); + + double threshold = innerproduct->convert_threshold(0.25); + EXPECT_DOUBLE_EQ(threshold, 0.25); + threshold = innerproduct->convert_threshold(0.5); + EXPECT_DOUBLE_EQ(threshold, 0.5); + threshold = innerproduct->convert_threshold(1.0); + EXPECT_DOUBLE_EQ(threshold, 1.0); } TEST(DistanceFunctionsTest, hamming_gives_expected_score) @@ -180,6 +205,13 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) double d25 = hamming->calc(t(points[2]), t(points[5])); EXPECT_EQ(d25, 1.0); EXPECT_DOUBLE_EQ(hamming->to_rawscore(d25), 1.0/(1.0 + 1.0)); + + double threshold = hamming->convert_threshold(0.25); + EXPECT_DOUBLE_EQ(threshold, 0.25); + threshold = hamming->convert_threshold(0.5); + EXPECT_DOUBLE_EQ(threshold, 0.5); + threshold = hamming->convert_threshold(1.0); + EXPECT_DOUBLE_EQ(threshold, 1.0); } TEST(GeoDegreesTest, gives_expected_score) diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index acc157709c0..d081c299a43 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -144,11 +144,20 @@ public: if (exp_hits.size() == k) { std::vector<uint32_t> expected_by_docid = exp_hits; std::sort(expected_by_docid.begin(), expected_by_docid.end()); - auto got_by_docid = index->find_top_k(k, qv, k); + auto got_by_docid = index->find_top_k(k, qv, k, 100100.25); for (idx = 0; idx < k; ++idx) { EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid); } } + if ((rv.size() > 1) && (rv[0].distance < rv[1].distance)) { + double thr = (rv[0].distance + rv[1].distance) * 0.5; + auto got_by_docid = index->find_top_k_with_filter(k, qv, *global_filter, k, thr); + for (const auto & hit : got_by_docid) { + printf("hit docid=%u dist=%g (thr %g)\n", hit.docid, hit.distance, thr); + } + EXPECT_EQ(got_by_docid.size(), 1); + EXPECT_EQ(got_by_docid[0].docid, rv[0].docid); + } } }; diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index dfcbfbbbe2b..70a59f1575a 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -736,6 +736,7 @@ public: n.get_target_num_hits(), n.get_allow_approximate(), n.get_explore_additional_hits(), + n.get_distance_threshold(), getRequestContext().get_attribute_blueprint_params().nearest_neighbor_brute_force_limit)); } }; diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp index 6039a86580c..6feec9bbba2 100644 --- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp +++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp @@ -273,6 +273,7 @@ SimpleQueryStackDumpIterator::next() _extraIntArg1 = readCompressedPositiveInt(p); // targetNumHits _extraIntArg2 = readCompressedPositiveInt(p); // allow_approximate _extraIntArg3 = readCompressedPositiveInt(p); // explore_additional_hits + _extraDoubleArg4 = read_double(p); // distance threshold _currArity = 0; } catch (...) { return false; diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h index d60765f3fe1..301929c8919 100644 --- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h +++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h @@ -119,6 +119,7 @@ public: uint32_t getNearDistance() const { return _extraIntArg1; } uint32_t getTargetNumHits() const { return _extraIntArg1; } + double getDistanceThreshold() const { return _extraDoubleArg4; } double getScoreThreshold() const { return _extraDoubleArg4; } double getThresholdBoostFactor() const { return _extraDoubleArg5; } bool getAllowApproximate() const { return (_extraIntArg2 != 0); } diff --git a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h index 8e6f2944ec9..8392730cd29 100644 --- a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h +++ b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h @@ -206,10 +206,12 @@ template <class NodeTypes> typename NodeTypes::NearestNeighborTerm * create_nearest_neighbor_term(vespalib::stringref query_tensor_name, vespalib::stringref field_name, int32_t id, Weight weight, uint32_t target_num_hits, - bool allow_approximate, uint32_t explore_additional_hits) + bool allow_approximate, uint32_t explore_additional_hits, + double distance_threshold) { return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, - target_num_hits, allow_approximate, explore_additional_hits); + target_num_hits, allow_approximate, explore_additional_hits, + distance_threshold); } template <class NodeTypes> @@ -321,9 +323,11 @@ public: } typename NodeTypes::NearestNeighborTerm &add_nearest_neighbor_term(stringref query_tensor_name, stringref field_name, int32_t id, Weight weight, uint32_t target_num_hits, - bool allow_approximate, uint32_t explore_additional_hits) { + bool allow_approximate, uint32_t explore_additional_hits, + double distance_threshold) + { adjustWeight(weight); - return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits, allow_approximate, explore_additional_hits)); + return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits, allow_approximate, explore_additional_hits, distance_threshold)); } }; diff --git a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h index 600249c3e1e..4b9226f6112 100644 --- a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h +++ b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h @@ -166,7 +166,8 @@ private: void visit(NearestNeighborTerm &node) override { replicate(node, _builder.add_nearest_neighbor_term(node.get_query_tensor_name(), node.getView(), node.getId(), node.getWeight(), node.get_target_num_hits(), - node.get_allow_approximate(), node.get_explore_additional_hits())); + node.get_allow_approximate(), node.get_explore_additional_hits(), + node.get_distance_threshold())); } }; diff --git a/searchlib/src/vespa/searchlib/query/tree/simplequery.h b/searchlib/src/vespa/searchlib/query/tree/simplequery.h index 4953f1a5b7c..db517edc348 100644 --- a/searchlib/src/vespa/searchlib/query/tree/simplequery.h +++ b/searchlib/src/vespa/searchlib/query/tree/simplequery.h @@ -106,9 +106,11 @@ struct SimpleRegExpTerm : RegExpTerm { struct SimpleNearestNeighborTerm : NearestNeighborTerm { SimpleNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, int32_t id, Weight weight, uint32_t target_num_hits, - bool allow_approximate, uint32_t explore_additional_hits) + bool allow_approximate, uint32_t explore_additional_hits, + double distance_threshold) : NearestNeighborTerm(query_tensor_name, field_name, id, weight, - target_num_hits, allow_approximate, explore_additional_hits) + target_num_hits, allow_approximate, explore_additional_hits, + distance_threshold) {} }; diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp index 9af1ecee224..a006e66310c 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp @@ -265,6 +265,7 @@ class QueryNodeConverter : public QueryVisitor { appendCompressedPositiveNumber(node.get_target_num_hits()); appendCompressedPositiveNumber(node.get_allow_approximate() ? 1 : 0); appendCompressedPositiveNumber(node.get_explore_additional_hits()); + appendDouble(node.get_distance_threshold()); } public: diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h index 66702fcd85c..040ac751d25 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h @@ -120,8 +120,10 @@ private: Weight weight = queryStack.GetWeight(); bool allow_approximate = queryStack.getAllowApproximate(); uint32_t explore_additional_hits = queryStack.getExploreAdditionalHits(); + double distance_threshold = queryStack.getDistanceThreshold(); builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, - target_num_hits, allow_approximate, explore_additional_hits); + target_num_hits, allow_approximate, explore_additional_hits, + distance_threshold); } else { vespalib::stringref term = queryStack.getTerm(); vespalib::stringref view = queryStack.getIndexName(); diff --git a/searchlib/src/vespa/searchlib/query/tree/termnodes.h b/searchlib/src/vespa/searchlib/query/tree/termnodes.h index 9af424716fb..e112fd6e295 100644 --- a/searchlib/src/vespa/searchlib/query/tree/termnodes.h +++ b/searchlib/src/vespa/searchlib/query/tree/termnodes.h @@ -130,22 +130,26 @@ private: uint32_t _target_num_hits; bool _allow_approximate; uint32_t _explore_additional_hits; + double _distance_threshold; public: NearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, int32_t id, Weight weight, uint32_t target_num_hits, - bool allow_approximate, uint32_t explore_additional_hits) + bool allow_approximate, uint32_t explore_additional_hits, + double distance_threshold) : QueryNodeMixinType(field_name, id, weight), _query_tensor_name(query_tensor_name), _target_num_hits(target_num_hits), _allow_approximate(allow_approximate), - _explore_additional_hits(explore_additional_hits) + _explore_additional_hits(explore_additional_hits), + _distance_threshold(distance_threshold) {} virtual ~NearestNeighborTerm() {} const vespalib::string& get_query_tensor_name() const { return _query_tensor_name; } uint32_t get_target_num_hits() const { return _target_num_hits; } bool get_allow_approximate() const { return _allow_approximate; } uint32_t get_explore_additional_hits() const { return _explore_additional_hits; } + double get_distance_threshold() const { return _distance_threshold; } }; diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index d3ecffd1605..01f02748664 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -52,13 +52,15 @@ struct ConvertCellsSelector NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& field, const tensor::DenseTensorAttribute& attr_tensor, std::unique_ptr<Value> query_tensor, - uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits, double brute_force_limit) + uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits, + double distance_threshold, double brute_force_limit) : ComplexLeafBlueprint(field), _attr_tensor(attr_tensor), _query_tensor(std::move(query_tensor)), _target_num_hits(target_num_hits), _approximate(approximate), _explore_additional_hits(explore_additional_hits), + _distance_threshold(std::numeric_limits<double>::max()), _brute_force_limit(brute_force_limit), _fallback_dist_fun(), _distance_heap(target_num_hits), @@ -72,9 +74,15 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f fixup_fun(_query_tensor, _attr_tensor.getTensorType()); _fallback_dist_fun = search::tensor::make_distance_function(_attr_tensor.getConfig().distance_metric(), rct); _dist_fun = _fallback_dist_fun.get(); + assert(_dist_fun); auto nns_index = _attr_tensor.nearest_neighbor_index(); if (nns_index) { _dist_fun = nns_index->distance_function(); + assert(_dist_fun); + } + if (distance_threshold < std::numeric_limits<double>::max()) { + _distance_threshold = _dist_fun->convert_threshold(distance_threshold); + _distance_heap.set_distance_threshold(_distance_threshold); } uint32_t est_hits = _attr_tensor.getNumDocs(); setEstimate(HitEstimate(est_hits, false)); @@ -127,9 +135,9 @@ NearestNeighborBlueprint::perform_top_k() uint32_t k = _target_num_hits; if (_global_filter->has_filter()) { auto filter = _global_filter->filter(); - _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits); + _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits, _distance_threshold); } else { - _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits); + _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold); } } } diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h index a8a0ff19246..aad43c923a2 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h @@ -24,6 +24,7 @@ private: uint32_t _target_num_hits; bool _approximate; uint32_t _explore_additional_hits; + double _distance_threshold; double _brute_force_limit; search::tensor::DistanceFunction::UP _fallback_dist_fun; const search::tensor::DistanceFunction *_dist_fun; @@ -36,7 +37,9 @@ public: NearestNeighborBlueprint(const queryeval::FieldSpec& field, const tensor::DenseTensorAttribute& attr_tensor, std::unique_ptr<vespalib::eval::Value> query_tensor, - uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits, double brute_force_limit); + uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits, + double distance_threshold, + double brute_force_limit); NearestNeighborBlueprint(const NearestNeighborBlueprint&) = delete; NearestNeighborBlueprint& operator=(const NearestNeighborBlueprint&) = delete; ~NearestNeighborBlueprint(); @@ -45,6 +48,7 @@ public: uint32_t get_target_num_hits() const { return _target_num_hits; } void set_global_filter(const GlobalFilter &global_filter) override; bool may_approximate() const { return _approximate; } + double get_distance_threshold() const { return _distance_threshold; } std::unique_ptr<SearchIterator> createLeafSearch(const search::fef::TermFieldMatchDataArray& tfmda, bool strict) const override; diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h index 3937dfba2ca..b7bdffd31c1 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h @@ -15,15 +15,22 @@ class NearestNeighborDistanceHeap { private: std::mutex _lock; size_t _size; + double _distance_threshold; vespalib::PriorityQueue<double, std::greater<double>> _priQ; public: - explicit NearestNeighborDistanceHeap(size_t maxSize) : _size(maxSize), _priQ() { + explicit NearestNeighborDistanceHeap(size_t maxSize) + : _size(maxSize), _distance_threshold(std::numeric_limits<double>::max()), + _priQ() + { _priQ.reserve(maxSize); } + void set_distance_threshold(double distance_threshold) { + _distance_threshold = distance_threshold; + } double distanceLimit() { std::lock_guard<std::mutex> guard(_lock); if (_priQ.size() < _size) { - return std::numeric_limits<double>::max(); + return _distance_threshold; } return _priQ.front(); } diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index 6488b525b7c..44b2ff2b7f1 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -281,6 +281,7 @@ HnswIndex::HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distan _level_generator(std::move(level_generator)), _cfg(cfg) { + assert(_distance_func); } HnswIndex::~HnswIndex() = default; @@ -534,7 +535,8 @@ struct NeighborsByDocId { std::vector<NearestNeighborIndex::Neighbor> HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector, - const BitVector *filter, uint32_t explore_k) const + const BitVector *filter, uint32_t explore_k, + double distance_threshold) const { std::vector<Neighbor> result; FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), filter); @@ -543,6 +545,7 @@ HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector, } result.reserve(candidates.size()); for (const HnswCandidate & hit : candidates.peek()) { + if (hit.distance > distance_threshold) continue; result.emplace_back(hit.docid, hit.distance); } std::sort(result.begin(), result.end(), NeighborsByDocId()); @@ -550,16 +553,18 @@ HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector, } std::vector<NearestNeighborIndex::Neighbor> -HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const +HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k, + double distance_threshold) const { - return top_k_by_docid(k, vector, nullptr, explore_k); + return top_k_by_docid(k, vector, nullptr, explore_k, distance_threshold); } std::vector<NearestNeighborIndex::Neighbor> HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector, - const BitVector &filter, uint32_t explore_k) const + const BitVector &filter, uint32_t explore_k, + double distance_threshold) const { - return top_k_by_docid(k, vector, &filter, explore_k); + return top_k_by_docid(k, vector, &filter, explore_k, distance_threshold); } FurthestPriQ diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index c07a0642b2e..5bd9d17adc3 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -123,7 +123,8 @@ protected: void search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors, uint32_t level, const search::BitVector *filter = nullptr) const; std::vector<Neighbor> top_k_by_docid(uint32_t k, TypedCells vector, - const BitVector *filter, uint32_t explore_k) const; + const BitVector *filter, uint32_t explore_k, + double distance_threshold) const; struct PreparedAddDoc : public PrepareResult { using ReadGuard = vespalib::GenerationHandler::Guard; @@ -166,9 +167,11 @@ public: std::unique_ptr<NearestNeighborIndexSaver> make_saver() const override; bool load(const fileutil::LoadedBuffer& buf) override; - std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const override; + std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k, + double distance_threshold) const override; std::vector<Neighbor> find_top_k_with_filter(uint32_t k, TypedCells vector, - const BitVector &filter, uint32_t explore_k) const override; + const BitVector &filter, uint32_t explore_k, + double distance_threshold) const override; const DistanceFunction *distance_function() const override { return _distance_func.get(); } FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector *filter) const; diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h index c14da0d058f..fd37cf80720 100644 --- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h @@ -73,13 +73,15 @@ public: virtual std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, - uint32_t explore_k) const = 0; + uint32_t explore_k, + double distance_threshold) const = 0; // only return neighbors where the corresponding filter bit is set virtual std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector, const BitVector &filter, - uint32_t explore_k) const = 0; + uint32_t explore_k, + double distance_threshold) const = 0; virtual const DistanceFunction *distance_function() const = 0; }; |