about | summary | refs | log | tree | commit | diff | stats
path: root/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp
diff options
context:
space:
mode:
author: Geir Storli <geirst@yahooinc.com> 2023-04-25 16:20:45 +0000
committer: Geir Storli <geirst@yahooinc.com> 2023-04-25 16:20:45 +0000
commit: 1d6fcfb5c5b7399cd33c32ffea30fd9208ec000b (patch)
tree: 4de5f00fa4e84e779d9ae01b766a6a0a6a61416f /streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp
parent: f3ac0e360e47778eb51e3619825f09e52d3b6082 (diff)
Use targetHits in nearestNeighbor streaming searcher.
A distance heap is used to limit the number of produced document matches.
Diffstat (limited to 'streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp')
-rw-r--r-- streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp | 21
1 file changed, 15 insertions, 6 deletions
diff --git a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp
index f064760e55d..db4ee12438e 100644
--- a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp
+++ b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp
@@ -48,8 +48,17 @@ NearestNeighborFieldSearcher::NodeAndCalc::NodeAndCalc(search::streaming::Neares
std::unique_ptr<search::tensor::DistanceCalculator> calc_in)
: node(node_in),
calc(std::move(calc_in)),
- distance_threshold(calc->function().convert_threshold(node->get_distance_threshold()))
+ heap(node->get_target_hits())
{
+ node->set_raw_score_calc(this);
+ heap.set_distance_threshold(calc->function().convert_threshold(node->get_distance_threshold()));
+}
+
+double
+NearestNeighborFieldSearcher::NodeAndCalc::to_raw_score(double distance)
+{
+ heap.used(distance);
+ return calc->function().to_rawscore(distance);
}
NearestNeighborFieldSearcher::NearestNeighborFieldSearcher(FieldIdT fid,
@@ -100,7 +109,7 @@ NearestNeighborFieldSearcher::prepare(search::streaming::QueryTermList& qtl,
}
try {
auto calc = DistanceCalculator::make_with_validation(*_attr, *tensor_value);
- _calcs.emplace_back(nn_term, std::move(calc));
+ _calcs.push_back(std::make_unique<NodeAndCalc>(nn_term, std::move(calc)));
} catch (const vespalib::IllegalArgumentException& ex) {
vespalib::Issue::report("Could not create DistanceCalculator for NearestNeighborQueryNode(%s, %s): %s",
nn_term->index().c_str(), nn_term->get_query_tensor_name().c_str(), ex.what());
@@ -116,10 +125,10 @@ NearestNeighborFieldSearcher::onValue(const document::FieldValue& fv)
if (tfv && tfv->getAsTensorPtr()) {
_attr->add(*tfv->getAsTensorPtr(), 1);
for (auto& elem : _calcs) {
- double distance = elem.calc->calc_with_limit(scratch_docid, elem.distance_threshold);
- if (distance <= elem.distance_threshold) {
- double score = elem.calc->function().to_rawscore(distance);
- elem.node->set_raw_score(score);
+ double distance_limit = elem->heap.distanceLimit();
+ double distance = elem->calc->calc_with_limit(scratch_docid, distance_limit);
+ if (distance <= distance_limit) {
+ elem->node->set_distance(distance);
}
}
}