summary | refs | log | tree | commit | diff | stats
path: root/searchlib
diff options
context:
space:
mode:
author: Arne Juul <arnej@verizonmedia.com> 2021-01-07 12:42:17 +0000
committer: Arne Juul <arnej@verizonmedia.com> 2021-01-08 10:56:14 +0000
commit: cf199f338efafad8c0af7de48094bd3d0037b96a (patch)
tree: 2c63ecee902b297006edb41a67925e721bfddff6 /searchlib
parent: 8aa9ffda4324ddd5baff87be858063c6399a26ca (diff)
add distanceThreshold option for nearestNeighbor operator
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp 2
-rw-r--r-- searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp 20
-rw-r--r-- searchlib/src/tests/query/query_visitor_test.cpp 2
-rw-r--r-- searchlib/src/tests/query/querybuilder_test.cpp 7
-rw-r--r-- searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp 16
-rw-r--r-- searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp 36
-rw-r--r-- searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp 11
-rw-r--r-- searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp 1
-rw-r--r-- searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp 1
-rw-r--r-- searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h 1
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/querybuilder.h 12
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/queryreplicator.h 3
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/simplequery.h 6
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp 1
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h 4
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/termnodes.h 8
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp 14
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h 6
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h 11
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp 15
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.h 9
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h 6
22 files changed, 153 insertions, 39 deletions
diff --git a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
index 51b4f1d760d..855510d0457 100644
--- a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
+++ b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
@@ -341,7 +341,7 @@ public:
request_ctx.set_query_tensor("query_tensor", tensor_spec);
}
Blueprint::UP create_blueprint() {
- query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7, true, 33);
+ query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7, true, 33, 100100.25);
return BlueprintFactoryFixture::create_blueprint(term);
}
};
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index 7b597af417d..cbdb2c9bd22 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -13,6 +13,7 @@
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
#include <vespa/searchlib/tensor/direct_tensor_attribute.h>
#include <vespa/searchlib/tensor/doc_vector_access.h>
+#include <vespa/searchlib/tensor/distance_functions.h>
#include <vespa/searchlib/tensor/hnsw_index.h>
#include <vespa/searchlib/tensor/nearest_neighbor_index.h>
#include <vespa/searchlib/tensor/nearest_neighbor_index_factory.h>
@@ -206,24 +207,32 @@ public:
_index_value = (reinterpret_cast<const int*>(buf.buffer()))[0];
return true;
}
- std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, uint32_t explore_k) const override {
+ std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, uint32_t explore_k,
+ double distance_threshold) const override
+ {
(void) k;
(void) vector;
(void) explore_k;
+ (void) distance_threshold;
return std::vector<Neighbor>();
}
std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector,
- const search::BitVector& filter, uint32_t explore_k) const override
+ const search::BitVector& filter, uint32_t explore_k,
+ double distance_threshold) const override
{
(void) k;
(void) vector;
(void) explore_k;
(void) filter;
+ (void) distance_threshold;
return std::vector<Neighbor>();
}
- const search::tensor::DistanceFunction *distance_function() const override { return nullptr; }
+ const search::tensor::DistanceFunction *distance_function() const override {
+ static search::tensor::SquaredEuclideanDistance<double> my_dist_fun;
+ return &my_dist_fun;
+ }
};
class MockNearestNeighborIndexFactory : public NearestNeighborIndexFactory {
@@ -914,9 +923,12 @@ public:
field,
as_dense_tensor(),
createDenseTensor(vec_2d(17, 42)),
- 3, true, 5, brute_force_limit);
+ 3, true, 5,
+ 100100.25,
+ brute_force_limit);
EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
EXPECT_TRUE(bp->may_approximate());
+ EXPECT_EQUAL(100100.25 * 100100.25, bp->get_distance_threshold());
return bp;
}
};
diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp
index 8441dc2227f..946ad17352d 100644
--- a/searchlib/src/tests/query/query_visitor_test.cpp
+++ b/searchlib/src/tests/query/query_visitor_test.cpp
@@ -99,7 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() {
checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0)));
checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0)));
checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0)));
- checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123, true, 321));
+ checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123, true, 321, 100100.25));
}
} // namespace
diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp
index 5a5a5eafb2c..30b4d2ae264 100644
--- a/searchlib/src/tests/query/querybuilder_test.cpp
+++ b/searchlib/src/tests/query/querybuilder_test.cpp
@@ -110,7 +110,7 @@ Node::UP createQueryTree() {
builder.addStringTerm(str[5], view[5], id[5], weight[6]);
builder.addStringTerm(str[6], view[6], id[6], weight[7]);
}
- builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7, true, 33);
+ builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7, true, 33, 100100.25);
}
Node::UP node = builder.build();
ASSERT_TRUE(node.get());
@@ -395,8 +395,9 @@ struct MyRegExpTerm : RegExpTerm {
struct MyNearestNeighborTerm : NearestNeighborTerm {
MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
int32_t i, Weight w, uint32_t target_num_hits,
- bool allow_approximate, uint32_t explore_additional_hits)
- : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits, allow_approximate, explore_additional_hits)
+ bool allow_approximate, uint32_t explore_additional_hits,
+ double distance_threshold)
+ : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits, allow_approximate, explore_additional_hits, distance_threshold)
{}
};
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index ad450a91f33..09790e7e360 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -121,11 +121,12 @@ struct Fixture
};
template <bool strict>
-SimpleResult find_matches(Fixture &env, const Value &qtv) {
+SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std::numeric_limits<double>::max()) {
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._tensorAttr);
NearestNeighborDistanceHeap dh(2);
+ dh.set_distance_threshold(env.dist_fun()->convert_threshold(threshold));
const BitVector *filter = env._global_filter.get();
auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, filter, env.dist_fun());
if (strict) {
@@ -159,6 +160,19 @@ verify_iterator_returns_expected_results(const vespalib::string& attribute_tenso
EXPECT_EQUAL(result, farExpect);
result = find_matches<false>(fixture, *farTensor);
EXPECT_EQUAL(result, farExpect);
+
+ SimpleResult null_thr5_exp({1,4,6});
+ result = find_matches<true>(fixture, *nullTensor, 5.0);
+ EXPECT_EQUAL(result, null_thr5_exp);
+ result = find_matches<false>(fixture, *nullTensor, 5.0);
+ EXPECT_EQUAL(result, null_thr5_exp);
+
+ SimpleResult far_thr4_exp({2,5});
+ result = find_matches<true>(fixture, *farTensor, 4.0);
+ EXPECT_EQUAL(result, far_thr4_exp);
+ result = find_matches<false>(fixture, *farTensor, 4.0);
+ EXPECT_EQUAL(result, far_thr4_exp);
+
}
TEST("require that NearestNeighborIterator returns expected results") {
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index 06fb95089fd..ee0a2aff80e 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -24,10 +24,12 @@ void verify_geo_miles(const DistanceFunction *dist_fun,
TypedCells t2(p2);
double abstract_distance = dist_fun->calc(t1, t2);
double raw_score = dist_fun->to_rawscore(abstract_distance);
- double m = ((1.0/raw_score)-1.0);
- double d_miles = m / 1.609344;
+ double km = ((1.0/raw_score)-1.0);
+ double d_miles = km / 1.609344;
EXPECT_GE(d_miles, exp_miles*0.99);
EXPECT_LE(d_miles, exp_miles*1.01);
+ double threshold = dist_fun->convert_threshold(km);
+ EXPECT_DOUBLE_EQ(threshold, abstract_distance);
}
@@ -50,6 +52,10 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
double d12 = euclid->calc(t(p1), t(p2));
EXPECT_EQ(d12, 2.0);
EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0)));
+ double threshold = euclid->convert_threshold(8.0);
+ EXPECT_EQ(threshold, 64.0);
+ threshold = euclid->convert_threshold(0.5);
+ EXPECT_EQ(threshold, 0.25);
}
TEST(DistanceFunctionsTest, angular_gives_expected_score)
@@ -75,19 +81,28 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score)
EXPECT_DOUBLE_EQ(a23, 1.0);
EXPECT_FLOAT_EQ(angular->to_rawscore(a12), 1.0/(1.0 + pi/2));
+ double threshold = angular->convert_threshold(pi/2);
+ EXPECT_DOUBLE_EQ(threshold, 1.0);
+
double a14 = angular->calc(t(p1), t(p4));
double a24 = angular->calc(t(p2), t(p4));
EXPECT_FLOAT_EQ(a14, 0.5);
EXPECT_FLOAT_EQ(a24, 0.5);
EXPECT_FLOAT_EQ(angular->to_rawscore(a14), 1.0/(1.0 + pi/3));
+ threshold = angular->convert_threshold(pi/3);
+ EXPECT_DOUBLE_EQ(threshold, 0.5);
double a34 = angular->calc(t(p3), t(p4));
EXPECT_FLOAT_EQ(a34, (1.0 - 0.707107));
EXPECT_FLOAT_EQ(angular->to_rawscore(a34), 1.0/(1.0 + pi/4));
+ threshold = angular->convert_threshold(pi/4);
+ EXPECT_FLOAT_EQ(threshold, a34);
double a25 = angular->calc(t(p2), t(p5));
EXPECT_DOUBLE_EQ(a25, 2.0);
EXPECT_FLOAT_EQ(angular->to_rawscore(a25), 1.0/(1.0 + pi));
+ threshold = angular->convert_threshold(pi);
+ EXPECT_FLOAT_EQ(threshold, 2.0);
double a44 = angular->calc(t(p4), t(p4));
EXPECT_GE(a44, 0.0);
@@ -98,6 +113,8 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score)
EXPECT_GE(a66, 0.0);
EXPECT_LT(a66, 0.000001);
EXPECT_FLOAT_EQ(angular->to_rawscore(a66), 1.0);
+ threshold = angular->convert_threshold(0.0);
+ EXPECT_FLOAT_EQ(threshold, 0.0);
double a16 = angular->calc(t(p1), t(p6));
double a26 = angular->calc(t(p2), t(p6));
@@ -127,6 +144,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score)
EXPECT_DOUBLE_EQ(i12, 1.0);
EXPECT_DOUBLE_EQ(i13, 1.0);
EXPECT_DOUBLE_EQ(i23, 1.0);
+
double i14 = innerproduct->calc(t(p1), t(p4));
double i24 = innerproduct->calc(t(p2), t(p4));
EXPECT_DOUBLE_EQ(i14, 0.5);
@@ -140,6 +158,13 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score)
double i44 = innerproduct->calc(t(p4), t(p4));
EXPECT_GE(i44, 0.0);
EXPECT_LT(i44, 0.000001);
+
+ double threshold = innerproduct->convert_threshold(0.25);
+ EXPECT_DOUBLE_EQ(threshold, 0.25);
+ threshold = innerproduct->convert_threshold(0.5);
+ EXPECT_DOUBLE_EQ(threshold, 0.5);
+ threshold = innerproduct->convert_threshold(1.0);
+ EXPECT_DOUBLE_EQ(threshold, 1.0);
}
TEST(DistanceFunctionsTest, hamming_gives_expected_score)
@@ -180,6 +205,13 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score)
double d25 = hamming->calc(t(points[2]), t(points[5]));
EXPECT_EQ(d25, 1.0);
EXPECT_DOUBLE_EQ(hamming->to_rawscore(d25), 1.0/(1.0 + 1.0));
+
+ double threshold = hamming->convert_threshold(0.25);
+ EXPECT_DOUBLE_EQ(threshold, 0.25);
+ threshold = hamming->convert_threshold(0.5);
+ EXPECT_DOUBLE_EQ(threshold, 0.5);
+ threshold = hamming->convert_threshold(1.0);
+ EXPECT_DOUBLE_EQ(threshold, 1.0);
}
TEST(GeoDegreesTest, gives_expected_score)
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index acc157709c0..d081c299a43 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -144,11 +144,20 @@ public:
if (exp_hits.size() == k) {
std::vector<uint32_t> expected_by_docid = exp_hits;
std::sort(expected_by_docid.begin(), expected_by_docid.end());
- auto got_by_docid = index->find_top_k(k, qv, k);
+ auto got_by_docid = index->find_top_k(k, qv, k, 100100.25);
for (idx = 0; idx < k; ++idx) {
EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid);
}
}
+ if ((rv.size() > 1) && (rv[0].distance < rv[1].distance)) {
+ double thr = (rv[0].distance + rv[1].distance) * 0.5;
+ auto got_by_docid = index->find_top_k_with_filter(k, qv, *global_filter, k, thr);
+ for (const auto & hit : got_by_docid) {
+ printf("hit docid=%u dist=%g (thr %g)\n", hit.docid, hit.distance, thr);
+ }
+ EXPECT_EQ(got_by_docid.size(), 1);
+ EXPECT_EQ(got_by_docid[0].docid, rv[0].docid);
+ }
}
};
diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
index dfcbfbbbe2b..70a59f1575a 100644
--- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
@@ -736,6 +736,7 @@ public:
n.get_target_num_hits(),
n.get_allow_approximate(),
n.get_explore_additional_hits(),
+ n.get_distance_threshold(),
getRequestContext().get_attribute_blueprint_params().nearest_neighbor_brute_force_limit));
}
};
diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
index 6039a86580c..6feec9bbba2 100644
--- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
+++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
@@ -273,6 +273,7 @@ SimpleQueryStackDumpIterator::next()
_extraIntArg1 = readCompressedPositiveInt(p); // targetNumHits
_extraIntArg2 = readCompressedPositiveInt(p); // allow_approximate
_extraIntArg3 = readCompressedPositiveInt(p); // explore_additional_hits
+ _extraDoubleArg4 = read_double(p); // distance threshold
_currArity = 0;
} catch (...) {
return false;
diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
index d60765f3fe1..301929c8919 100644
--- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
+++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
@@ -119,6 +119,7 @@ public:
uint32_t getNearDistance() const { return _extraIntArg1; }
uint32_t getTargetNumHits() const { return _extraIntArg1; }
+ double getDistanceThreshold() const { return _extraDoubleArg4; }
double getScoreThreshold() const { return _extraDoubleArg4; }
double getThresholdBoostFactor() const { return _extraDoubleArg5; }
bool getAllowApproximate() const { return (_extraIntArg2 != 0); }
diff --git a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
index 8e6f2944ec9..8392730cd29 100644
--- a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
+++ b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
@@ -206,10 +206,12 @@ template <class NodeTypes>
typename NodeTypes::NearestNeighborTerm *
create_nearest_neighbor_term(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
int32_t id, Weight weight, uint32_t target_num_hits,
- bool allow_approximate, uint32_t explore_additional_hits)
+ bool allow_approximate, uint32_t explore_additional_hits,
+ double distance_threshold)
{
return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight,
- target_num_hits, allow_approximate, explore_additional_hits);
+ target_num_hits, allow_approximate, explore_additional_hits,
+ distance_threshold);
}
template <class NodeTypes>
@@ -321,9 +323,11 @@ public:
}
typename NodeTypes::NearestNeighborTerm &add_nearest_neighbor_term(stringref query_tensor_name, stringref field_name,
int32_t id, Weight weight, uint32_t target_num_hits,
- bool allow_approximate, uint32_t explore_additional_hits) {
+ bool allow_approximate, uint32_t explore_additional_hits,
+ double distance_threshold)
+ {
adjustWeight(weight);
- return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits, allow_approximate, explore_additional_hits));
+ return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits, allow_approximate, explore_additional_hits, distance_threshold));
}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
index 600249c3e1e..4b9226f6112 100644
--- a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
+++ b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
@@ -166,7 +166,8 @@ private:
void visit(NearestNeighborTerm &node) override {
replicate(node, _builder.add_nearest_neighbor_term(node.get_query_tensor_name(), node.getView(),
node.getId(), node.getWeight(), node.get_target_num_hits(),
- node.get_allow_approximate(), node.get_explore_additional_hits()));
+ node.get_allow_approximate(), node.get_explore_additional_hits(),
+ node.get_distance_threshold()));
}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/simplequery.h b/searchlib/src/vespa/searchlib/query/tree/simplequery.h
index 4953f1a5b7c..db517edc348 100644
--- a/searchlib/src/vespa/searchlib/query/tree/simplequery.h
+++ b/searchlib/src/vespa/searchlib/query/tree/simplequery.h
@@ -106,9 +106,11 @@ struct SimpleRegExpTerm : RegExpTerm {
struct SimpleNearestNeighborTerm : NearestNeighborTerm {
SimpleNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
int32_t id, Weight weight, uint32_t target_num_hits,
- bool allow_approximate, uint32_t explore_additional_hits)
+ bool allow_approximate, uint32_t explore_additional_hits,
+ double distance_threshold)
: NearestNeighborTerm(query_tensor_name, field_name, id, weight,
- target_num_hits, allow_approximate, explore_additional_hits)
+ target_num_hits, allow_approximate, explore_additional_hits,
+ distance_threshold)
{}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
index 9af1ecee224..a006e66310c 100644
--- a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
+++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
@@ -265,6 +265,7 @@ class QueryNodeConverter : public QueryVisitor {
appendCompressedPositiveNumber(node.get_target_num_hits());
appendCompressedPositiveNumber(node.get_allow_approximate() ? 1 : 0);
appendCompressedPositiveNumber(node.get_explore_additional_hits());
+ appendDouble(node.get_distance_threshold());
}
public:
diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
index 66702fcd85c..040ac751d25 100644
--- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
+++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
@@ -120,8 +120,10 @@ private:
Weight weight = queryStack.GetWeight();
bool allow_approximate = queryStack.getAllowApproximate();
uint32_t explore_additional_hits = queryStack.getExploreAdditionalHits();
+ double distance_threshold = queryStack.getDistanceThreshold();
builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight,
- target_num_hits, allow_approximate, explore_additional_hits);
+ target_num_hits, allow_approximate, explore_additional_hits,
+ distance_threshold);
} else {
vespalib::stringref term = queryStack.getTerm();
vespalib::stringref view = queryStack.getIndexName();
diff --git a/searchlib/src/vespa/searchlib/query/tree/termnodes.h b/searchlib/src/vespa/searchlib/query/tree/termnodes.h
index 9af424716fb..e112fd6e295 100644
--- a/searchlib/src/vespa/searchlib/query/tree/termnodes.h
+++ b/searchlib/src/vespa/searchlib/query/tree/termnodes.h
@@ -130,22 +130,26 @@ private:
uint32_t _target_num_hits;
bool _allow_approximate;
uint32_t _explore_additional_hits;
+ double _distance_threshold;
public:
NearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
int32_t id, Weight weight, uint32_t target_num_hits,
- bool allow_approximate, uint32_t explore_additional_hits)
+ bool allow_approximate, uint32_t explore_additional_hits,
+ double distance_threshold)
: QueryNodeMixinType(field_name, id, weight),
_query_tensor_name(query_tensor_name),
_target_num_hits(target_num_hits),
_allow_approximate(allow_approximate),
- _explore_additional_hits(explore_additional_hits)
+ _explore_additional_hits(explore_additional_hits),
+ _distance_threshold(distance_threshold)
{}
virtual ~NearestNeighborTerm() {}
const vespalib::string& get_query_tensor_name() const { return _query_tensor_name; }
uint32_t get_target_num_hits() const { return _target_num_hits; }
bool get_allow_approximate() const { return _allow_approximate; }
uint32_t get_explore_additional_hits() const { return _explore_additional_hits; }
+ double get_distance_threshold() const { return _distance_threshold; }
};
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index d3ecffd1605..01f02748664 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -52,13 +52,15 @@ struct ConvertCellsSelector
NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& field,
const tensor::DenseTensorAttribute& attr_tensor,
std::unique_ptr<Value> query_tensor,
- uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits, double brute_force_limit)
+ uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits,
+ double distance_threshold, double brute_force_limit)
: ComplexLeafBlueprint(field),
_attr_tensor(attr_tensor),
_query_tensor(std::move(query_tensor)),
_target_num_hits(target_num_hits),
_approximate(approximate),
_explore_additional_hits(explore_additional_hits),
+ _distance_threshold(std::numeric_limits<double>::max()),
_brute_force_limit(brute_force_limit),
_fallback_dist_fun(),
_distance_heap(target_num_hits),
@@ -72,9 +74,15 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
fixup_fun(_query_tensor, _attr_tensor.getTensorType());
_fallback_dist_fun = search::tensor::make_distance_function(_attr_tensor.getConfig().distance_metric(), rct);
_dist_fun = _fallback_dist_fun.get();
+ assert(_dist_fun);
auto nns_index = _attr_tensor.nearest_neighbor_index();
if (nns_index) {
_dist_fun = nns_index->distance_function();
+ assert(_dist_fun);
+ }
+ if (distance_threshold < std::numeric_limits<double>::max()) {
+ _distance_threshold = _dist_fun->convert_threshold(distance_threshold);
+ _distance_heap.set_distance_threshold(_distance_threshold);
}
uint32_t est_hits = _attr_tensor.getNumDocs();
setEstimate(HitEstimate(est_hits, false));
@@ -127,9 +135,9 @@ NearestNeighborBlueprint::perform_top_k()
uint32_t k = _target_num_hits;
if (_global_filter->has_filter()) {
auto filter = _global_filter->filter();
- _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits);
+ _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits, _distance_threshold);
} else {
- _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits);
+ _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold);
}
}
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
index a8a0ff19246..aad43c923a2 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
@@ -24,6 +24,7 @@ private:
uint32_t _target_num_hits;
bool _approximate;
uint32_t _explore_additional_hits;
+ double _distance_threshold;
double _brute_force_limit;
search::tensor::DistanceFunction::UP _fallback_dist_fun;
const search::tensor::DistanceFunction *_dist_fun;
@@ -36,7 +37,9 @@ public:
NearestNeighborBlueprint(const queryeval::FieldSpec& field,
const tensor::DenseTensorAttribute& attr_tensor,
std::unique_ptr<vespalib::eval::Value> query_tensor,
- uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits, double brute_force_limit);
+ uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits,
+ double distance_threshold,
+ double brute_force_limit);
NearestNeighborBlueprint(const NearestNeighborBlueprint&) = delete;
NearestNeighborBlueprint& operator=(const NearestNeighborBlueprint&) = delete;
~NearestNeighborBlueprint();
@@ -45,6 +48,7 @@ public:
uint32_t get_target_num_hits() const { return _target_num_hits; }
void set_global_filter(const GlobalFilter &global_filter) override;
bool may_approximate() const { return _approximate; }
+ double get_distance_threshold() const { return _distance_threshold; }
std::unique_ptr<SearchIterator> createLeafSearch(const search::fef::TermFieldMatchDataArray& tfmda,
bool strict) const override;
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h
index 3937dfba2ca..b7bdffd31c1 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h
@@ -15,15 +15,22 @@ class NearestNeighborDistanceHeap {
private:
std::mutex _lock;
size_t _size;
+ double _distance_threshold;
vespalib::PriorityQueue<double, std::greater<double>> _priQ;
public:
- explicit NearestNeighborDistanceHeap(size_t maxSize) : _size(maxSize), _priQ() {
+ explicit NearestNeighborDistanceHeap(size_t maxSize)
+ : _size(maxSize), _distance_threshold(std::numeric_limits<double>::max()),
+ _priQ()
+ {
_priQ.reserve(maxSize);
}
+ void set_distance_threshold(double distance_threshold) {
+ _distance_threshold = distance_threshold;
+ }
double distanceLimit() {
std::lock_guard<std::mutex> guard(_lock);
if (_priQ.size() < _size) {
- return std::numeric_limits<double>::max();
+ return _distance_threshold;
}
return _priQ.front();
}
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index 6488b525b7c..44b2ff2b7f1 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -281,6 +281,7 @@ HnswIndex::HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distan
_level_generator(std::move(level_generator)),
_cfg(cfg)
{
+ assert(_distance_func);
}
HnswIndex::~HnswIndex() = default;
@@ -534,7 +535,8 @@ struct NeighborsByDocId {
std::vector<NearestNeighborIndex::Neighbor>
HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector,
- const BitVector *filter, uint32_t explore_k) const
+ const BitVector *filter, uint32_t explore_k,
+ double distance_threshold) const
{
std::vector<Neighbor> result;
FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), filter);
@@ -543,6 +545,7 @@ HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector,
}
result.reserve(candidates.size());
for (const HnswCandidate & hit : candidates.peek()) {
+ if (hit.distance > distance_threshold) continue;
result.emplace_back(hit.docid, hit.distance);
}
std::sort(result.begin(), result.end(), NeighborsByDocId());
@@ -550,16 +553,18 @@ HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector,
}
std::vector<NearestNeighborIndex::Neighbor>
-HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const
+HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k,
+ double distance_threshold) const
{
- return top_k_by_docid(k, vector, nullptr, explore_k);
+ return top_k_by_docid(k, vector, nullptr, explore_k, distance_threshold);
}
std::vector<NearestNeighborIndex::Neighbor>
HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector,
- const BitVector &filter, uint32_t explore_k) const
+ const BitVector &filter, uint32_t explore_k,
+ double distance_threshold) const
{
- return top_k_by_docid(k, vector, &filter, explore_k);
+ return top_k_by_docid(k, vector, &filter, explore_k, distance_threshold);
}
FurthestPriQ
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
index c07a0642b2e..5bd9d17adc3 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
@@ -123,7 +123,8 @@ protected:
void search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors,
uint32_t level, const search::BitVector *filter = nullptr) const;
std::vector<Neighbor> top_k_by_docid(uint32_t k, TypedCells vector,
- const BitVector *filter, uint32_t explore_k) const;
+ const BitVector *filter, uint32_t explore_k,
+ double distance_threshold) const;
struct PreparedAddDoc : public PrepareResult {
using ReadGuard = vespalib::GenerationHandler::Guard;
@@ -166,9 +167,11 @@ public:
std::unique_ptr<NearestNeighborIndexSaver> make_saver() const override;
bool load(const fileutil::LoadedBuffer& buf) override;
- std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const override;
+ std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k,
+ double distance_threshold) const override;
std::vector<Neighbor> find_top_k_with_filter(uint32_t k, TypedCells vector,
- const BitVector &filter, uint32_t explore_k) const override;
+ const BitVector &filter, uint32_t explore_k,
+ double distance_threshold) const override;
const DistanceFunction *distance_function() const override { return _distance_func.get(); }
FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector *filter) const;
diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
index c14da0d058f..fd37cf80720 100644
--- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
@@ -73,13 +73,15 @@ public:
virtual std::vector<Neighbor> find_top_k(uint32_t k,
vespalib::eval::TypedCells vector,
- uint32_t explore_k) const = 0;
+ uint32_t explore_k,
+ double distance_threshold) const = 0;
// only return neighbors where the corresponding filter bit is set
virtual std::vector<Neighbor> find_top_k_with_filter(uint32_t k,
vespalib::eval::TypedCells vector,
const BitVector &filter,
- uint32_t explore_k) const = 0;
+ uint32_t explore_k,
+ double distance_threshold) const = 0;
virtual const DistanceFunction *distance_function() const = 0;
};