aboutsummaryrefslogtreecommitdiffstats
path: root/streamingvisitors
diff options
context:
space:
mode:
author Geir Storli <geirst@yahooinc.com> 2023-04-25 16:20:45 +0000
committer Geir Storli <geirst@yahooinc.com> 2023-04-25 16:20:45 +0000
commit 1d6fcfb5c5b7399cd33c32ffea30fd9208ec000b (patch)
tree 4de5f00fa4e84e779d9ae01b766a6a0a6a61416f /streamingvisitors
parent f3ac0e360e47778eb51e3619825f09e52d3b6082 (diff)
Use targetHits in nearestNeighbor streaming searcher.
A distance heap is used to limit the number of produced document matches.
Diffstat (limited to 'streamingvisitors')
-rw-r--r-- streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp 61
-rw-r--r-- streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp 12
-rw-r--r-- streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp 2
-rw-r--r-- streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp 21
-rw-r--r-- streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h 13
5 files changed, 83 insertions, 26 deletions
diff --git a/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp b/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp
index 43c77398be8..b64d477fd4c 100644
--- a/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp
+++ b/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp
@@ -31,9 +31,11 @@ struct MockQuery {
std::vector<std::unique_ptr<NearestNeighborQueryNode>> nodes;
QueryTermList term_list;
MockQuery& add(const vespalib::string& query_tensor_name,
+ uint32_t target_hits,
double distance_threshold) {
std::unique_ptr<QueryNodeResultBase> base;
- auto node = std::make_unique<NearestNeighborQueryNode>(std::move(base), query_tensor_name, "my_tensor_field", 7, search::query::Weight(11), distance_threshold);
+ auto node = std::make_unique<NearestNeighborQueryNode>(std::move(base), query_tensor_name, "my_tensor_field",
+ target_hits, distance_threshold, 7, search::query::Weight(100));
nodes.push_back(std::move(node));
term_list.push_back(nodes.back().get());
return *this;
@@ -90,34 +92,71 @@ public:
query.reset();
searcher.onValue(fv);
}
+ void expect_match(const vespalib::string& spec_expr, double exp_square_distance, const NearestNeighborQueryNode& node) {
+ match(spec_expr);
+ expect_match(exp_square_distance, node);
+ }
void expect_match(double exp_square_distance, const NearestNeighborQueryNode& node) {
double exp_raw_score = dist_func.to_rawscore(exp_square_distance);
EXPECT_TRUE(node.evaluate());
+ EXPECT_DOUBLE_EQ(exp_square_distance, node.get_distance().value());
EXPECT_DOUBLE_EQ(exp_raw_score, node.get_raw_score().value());
}
+ void expect_not_match(const vespalib::string& spec_expr, const NearestNeighborQueryNode& node) {
+ match(spec_expr);
+ EXPECT_FALSE(node.evaluate());
+ }
};
-TEST_F(NearestNeighborSearcherTest, raw_score_calculated_with_distance_threshold)
+TEST_F(NearestNeighborSearcherTest, distance_heap_keeps_the_best_target_hits)
{
- query.add("qt1", 3);
+ query.add("qt1", 2, 100.0);
+ const auto& node = query.get(0);
set_query_tensor("qt1", "tensor(x[2]):[1,3]");
prepare();
- match("tensor(x[2]):[1,5]");
- expect_match((5-3)*(5-3), query.get(0));
+ expect_match("tensor(x[2]):[1,7]", (7-3)*(7-3), node);
+ expect_match("tensor(x[2]):[1,9]", (9-3)*(9-3), node);
- match("tensor(x[2]):[1,6]");
- expect_match((6-3)*(6-3), query.get(0));
+ // The distance limit is now (9-3)*(9-3) = 36, so this is not good enough.
+ expect_not_match("tensor(x[2]):[1,10]", node);
+
+ expect_match("tensor(x[2]):[1,5]", (5-3)*(5-3), node);
+
+ // The distance limit is now (7-3)*(7-3) = 16, so this is not good enough.
+ expect_not_match("tensor(x[2]):[1,8]", node);
+
+ // This value is within the distance limit, but since get_raw_score() is not called
+ // it is not counted as a document match and the distance heap is not updated.
+ match("tensor(x[2]):[1,4]");
+ EXPECT_EQ(1, node.get_distance().value());
+ EXPECT_TRUE(node.evaluate());
+
+ // The distance limit is still (7-3)*(7-3) = 16, so this is in fact good enough.
+ expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node);
+
+ // The distance limit is now (6-3)*(6-3) = 9, and an equal distance is still a match.
+ expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node);
+}
+
+TEST_F(NearestNeighborSearcherTest, raw_score_calculated_with_distance_threshold)
+{
+ query.add("qt1", 10, 3.0);
+ const auto& node = query.get(0);
+ set_query_tensor("qt1", "tensor(x[2]):[1,3]");
+ prepare();
+
+ expect_match("tensor(x[2]):[1,5]", (5-3)*(5-3), node);
+ expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node);
- match("tensor(x[2]):[1,7]");
// This is not a match since ((7-3)*(7-3) = 16) is larger than the internal distance threshold of (3*3 = 9).
- EXPECT_FALSE(query.get(0).evaluate());
+ expect_not_match("tensor(x[2]):[1,7]", node);
}
TEST_F(NearestNeighborSearcherTest, raw_score_calculated_for_two_query_operators)
{
- query.add("qt1", 3);
- query.add("qt2", 4);
+ query.add("qt1", 10, 3.0);
+ query.add("qt2", 10, 4.0);
set_query_tensor("qt1", "tensor(x[2]):[1,3]");
set_query_tensor("qt2", "tensor(x[2]):[1,4]");
prepare();
diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
index 9f3f3d770e4..4d425d9dedd 100644
--- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
+++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
@@ -55,6 +55,10 @@ RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder)
_query_wrapper = std::make_unique<QueryWrapper>(*_query);
}
+class MockRawScoreCalculator : public search::streaming::NearestNeighborQueryNode::RawScoreCalculator {
+public:
+ double to_raw_score(double distance) override { return distance * 2; }
+};
TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node)
{
@@ -71,6 +75,8 @@ TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node)
EXPECT_EQ(1u, term_list.size());
auto node = dynamic_cast<NearestNeighborQueryNode*>(term_list.front().getTerm());
EXPECT_NE(nullptr, node);
+ MockRawScoreCalculator calc;
+ node->set_raw_score_calc(&calc);
auto& qtd = static_cast<QueryTermData &>(node->getQueryItem());
auto& td = qtd.getTermData();
constexpr TermFieldHandle handle = 27;
@@ -82,11 +88,11 @@ TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node)
EXPECT_EQ(invalid_id, tfmd->getDocId());
RankProcessor::unpack_match_data(1, *md, *_query_wrapper);
EXPECT_EQ(invalid_id, tfmd->getDocId());
- constexpr double raw_score = 1.5;
- node->set_raw_score(raw_score);
+ constexpr double distance = 1.5;
+ node->set_distance(distance);
RankProcessor::unpack_match_data(2, *md, *_query_wrapper);
EXPECT_EQ(2, tfmd->getDocId());
- EXPECT_EQ(raw_score, tfmd->getRawScore());
+ EXPECT_EQ(distance * 2, tfmd->getRawScore());
node->reset();
RankProcessor::unpack_match_data(3, *md, *_query_wrapper);
EXPECT_EQ(2, tfmd->getDocId());
diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
index ba97a708cc5..3751ba379d0 100644
--- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
+++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
@@ -241,7 +241,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap
for (QueryWrapper::Term & term: query.getTermList()) {
auto nn_node = term.getTerm()->as_nearest_neighbor_query_node();
if (nn_node != nullptr) {
- auto& raw_score = nn_node->get_raw_score();
+ auto raw_score = nn_node->get_raw_score();
if (raw_score.has_value()) {
auto& qtd = static_cast<QueryTermData &>(term.getTerm()->getQueryItem());
auto& td = qtd.getTermData();
diff --git a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp
index f064760e55d..db4ee12438e 100644
--- a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp
+++ b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp
@@ -48,8 +48,17 @@ NearestNeighborFieldSearcher::NodeAndCalc::NodeAndCalc(search::streaming::Neares
std::unique_ptr<search::tensor::DistanceCalculator> calc_in)
: node(node_in),
calc(std::move(calc_in)),
- distance_threshold(calc->function().convert_threshold(node->get_distance_threshold()))
+ heap(node->get_target_hits())
{
+ node->set_raw_score_calc(this);
+ heap.set_distance_threshold(calc->function().convert_threshold(node->get_distance_threshold()));
+}
+
+double
+NearestNeighborFieldSearcher::NodeAndCalc::to_raw_score(double distance)
+{
+ heap.used(distance);
+ return calc->function().to_rawscore(distance);
}
NearestNeighborFieldSearcher::NearestNeighborFieldSearcher(FieldIdT fid,
@@ -100,7 +109,7 @@ NearestNeighborFieldSearcher::prepare(search::streaming::QueryTermList& qtl,
}
try {
auto calc = DistanceCalculator::make_with_validation(*_attr, *tensor_value);
- _calcs.emplace_back(nn_term, std::move(calc));
+ _calcs.push_back(std::make_unique<NodeAndCalc>(nn_term, std::move(calc)));
} catch (const vespalib::IllegalArgumentException& ex) {
vespalib::Issue::report("Could not create DistanceCalculator for NearestNeighborQueryNode(%s, %s): %s",
nn_term->index().c_str(), nn_term->get_query_tensor_name().c_str(), ex.what());
@@ -116,10 +125,10 @@ NearestNeighborFieldSearcher::onValue(const document::FieldValue& fv)
if (tfv && tfv->getAsTensorPtr()) {
_attr->add(*tfv->getAsTensorPtr(), 1);
for (auto& elem : _calcs) {
- double distance = elem.calc->calc_with_limit(scratch_docid, elem.distance_threshold);
- if (distance <= elem.distance_threshold) {
- double score = elem.calc->function().to_rawscore(distance);
- elem.node->set_raw_score(score);
+ double distance_limit = elem->heap.distanceLimit();
+ double distance = elem->calc->calc_with_limit(scratch_docid, distance_limit);
+ if (distance <= distance_limit) {
+ elem->node->set_distance(distance);
}
}
}
diff --git a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h
index ba39b91c677..d5d751cd637 100644
--- a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h
+++ b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h
@@ -5,6 +5,8 @@
#include "fieldsearcher.h"
#include <vespa/eval/eval/value_type.h>
#include <vespa/searchcommon/attribute/distance_metric.h>
+#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h>
+#include <vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h>
#include <vespa/searchlib/tensor/distance_calculator.h>
#include <vespa/searchlib/tensor/tensor_ext_attribute.h>
@@ -14,8 +16,6 @@ namespace search::tensor {
class TensorExtAttribute;
}
-namespace search::streaming { class NearestNeighborQueryNode; }
-
namespace vsm {
/**
@@ -26,16 +26,19 @@ namespace vsm {
*/
class NearestNeighborFieldSearcher : public FieldSearcher {
private:
- struct NodeAndCalc {
+ class NodeAndCalc : search::streaming::NearestNeighborQueryNode::RawScoreCalculator {
+ public:
search::streaming::NearestNeighborQueryNode* node;
std::unique_ptr<search::tensor::DistanceCalculator> calc;
- double distance_threshold;
+ search::queryeval::NearestNeighborDistanceHeap heap;
NodeAndCalc(search::streaming::NearestNeighborQueryNode* node_in,
std::unique_ptr<search::tensor::DistanceCalculator> calc_in);
+
+ double to_raw_score(double distance) override;
};
search::attribute::DistanceMetric _metric;
std::unique_ptr<search::tensor::TensorExtAttribute> _attr;
- std::vector<NodeAndCalc> _calcs;
+ std::vector<std::unique_ptr<NodeAndCalc>> _calcs;
public:
NearestNeighborFieldSearcher(FieldIdT fid,