summary refs log tree commit diff stats
path: root/searchlib
diff options
context:
space:
mode:
author Geir Storli <geirst@yahooinc.com> 2023-04-25 16:20:45 +0000
committer Geir Storli <geirst@yahooinc.com> 2023-04-25 16:20:45 +0000
commit 1d6fcfb5c5b7399cd33c32ffea30fd9208ec000b (patch)
tree 4de5f00fa4e84e779d9ae01b766a6a0a6a61416f /searchlib
parent f3ac0e360e47778eb51e3619825f09e52d3b6082 (diff)
Use targetHits in nearestNeighbor streaming searcher.
A distance heap is used to limit the number of produced document matches.
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/tests/query/streaming_query_test.cpp 12
-rw-r--r-- searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h 2
-rw-r--r-- searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp 28
-rw-r--r-- searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h 34
-rw-r--r-- searchlib/src/vespa/searchlib/query/streaming/querynode.cpp 12
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h 6
6 files changed, 67 insertions, 27 deletions
diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp
index 2c202d9131b..210f32af15e 100644
--- a/searchlib/src/tests/query/streaming_query_test.cpp
+++ b/searchlib/src/tests/query/streaming_query_test.cpp
@@ -814,7 +814,7 @@ TEST("test_nearest_neighbor_query_node")
constexpr uint32_t target_num_hits = 100;
constexpr bool allow_approximate = false;
constexpr uint32_t explore_additional_hits = 800;
- constexpr double raw_score = 0.5;
+ constexpr double distance = 0.5;
builder.add_nearest_neighbor_term("qtensor", "field", id, Weight(weight), target_num_hits, allow_approximate, explore_additional_hits, distance_threshold);
auto build_node = builder.build();
auto stack_dump = StackDumpCreator::create(*build_node);
@@ -830,14 +830,14 @@ TEST("test_nearest_neighbor_query_node")
EXPECT_EQUAL(id, static_cast<int32_t>(node->uniqueId()));
EXPECT_EQUAL(weight, node->weight().percent());
EXPECT_EQUAL(distance_threshold, node->get_distance_threshold());
- EXPECT_FALSE(node->get_raw_score().has_value());
+ EXPECT_FALSE(node->get_distance().has_value());
EXPECT_FALSE(node->evaluate());
- node->set_raw_score(raw_score);
- EXPECT_TRUE(node->get_raw_score().has_value());
- EXPECT_EQUAL(raw_score, node->get_raw_score().value());
+ node->set_distance(distance);
+ EXPECT_TRUE(node->get_distance().has_value());
+ EXPECT_EQUAL(distance, node->get_distance().value());
EXPECT_TRUE(node->evaluate());
node->reset();
- EXPECT_FALSE(node->get_raw_score().has_value());
+ EXPECT_FALSE(node->get_distance().has_value());
EXPECT_FALSE(node->evaluate());
}
diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
index 46b89fdfeb4..9bef389a278 100644
--- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
+++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
@@ -109,7 +109,7 @@ public:
uint32_t getArity() const { return _currArity; }
uint32_t getNearDistance() const { return _extraIntArg1; }
- uint32_t getTargetNumHits() const { return _extraIntArg1; }
+ uint32_t getTargetHits() const { return _extraIntArg1; }
double getDistanceThreshold() const { return _extraDoubleArg4; }
double getScoreThreshold() const { return _extraDoubleArg4; }
double getThresholdBoostFactor() const { return _extraDoubleArg5; }
diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp
index d1c37cd6dcd..b2d8a0ee4be 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp
@@ -1,15 +1,21 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "nearest_neighbor_query_node.h"
+#include <cassert>
namespace search::streaming {
-NearestNeighborQueryNode::NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold)
- : QueryTerm(std::move(resultBase), term, index, Type::NEAREST_NEIGHBOR),
+NearestNeighborQueryNode::NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase,
+ const string& query_tensor_name, const string& field_name,
+ uint32_t target_hits, double distance_threshold,
+ int32_t unique_id, search::query::Weight weight)
+ : QueryTerm(std::move(resultBase), query_tensor_name, field_name, Type::NEAREST_NEIGHBOR),
+ _target_hits(target_hits),
_distance_threshold(distance_threshold),
- _raw_score()
+ _distance(),
+ _calc()
{
- setUniqueId(id);
+ setUniqueId(unique_id);
setWeight(weight);
}
@@ -18,13 +24,13 @@ NearestNeighborQueryNode::~NearestNeighborQueryNode() = default;
bool
NearestNeighborQueryNode::evaluate() const
{
- return _raw_score.has_value();
+ return _distance.has_value();
}
void
NearestNeighborQueryNode::reset()
{
- _raw_score.reset();
+ _distance.reset();
}
NearestNeighborQueryNode*
@@ -33,4 +39,14 @@ NearestNeighborQueryNode::as_nearest_neighbor_query_node() noexcept
return this;
}
+std::optional<double>
+NearestNeighborQueryNode::get_raw_score() const
+{
+ if (_distance.has_value()) {
+ assert(_calc != nullptr);
+ return _calc->to_raw_score(_distance.value());
+ }
+ return std::nullopt;
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h
index 0beb130c53d..c66364b0c52 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h
@@ -8,16 +8,34 @@
namespace search::streaming {
/*
- * Nearest neighbor query node.
+ * Nearest neighbor query node for streaming search.
*/
class NearestNeighborQueryNode: public QueryTerm {
+public:
+ class RawScoreCalculator {
+ public:
+ virtual ~RawScoreCalculator() = default;
+ /**
+ * Convert the given distance to a raw score.
+ *
+ * This is used during unpacking, and also signals that the entire document was a match.
+ */
+ virtual double to_raw_score(double distance) = 0;
+ };
+
private:
+ uint32_t _target_hits;
double _distance_threshold;
- // When this value is set it also indicates a match
- std::optional<double> _raw_score;
+ // When this value is set it also indicates a match for this query node.
+ std::optional<double> _distance;
+ RawScoreCalculator* _calc;
+
public:
- NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold);
+ NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase,
+ const string& query_tensor_name, const string& field_name,
+ uint32_t target_hits, double distance_threshold,
+ int32_t unique_id, search::query::Weight weight);
NearestNeighborQueryNode(const NearestNeighborQueryNode &) = delete;
NearestNeighborQueryNode & operator = (const NearestNeighborQueryNode &) = delete;
NearestNeighborQueryNode(NearestNeighborQueryNode &&) = delete;
@@ -27,9 +45,13 @@ public:
void reset() override;
NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept override;
const vespalib::string& get_query_tensor_name() const { return getTermString(); }
+ uint32_t get_target_hits() const { return _target_hits; }
double get_distance_threshold() const { return _distance_threshold; }
- void set_raw_score(double value) { _raw_score = value; }
- const std::optional<double>& get_raw_score() const noexcept { return _raw_score; }
+ void set_raw_score_calc(RawScoreCalculator* calc_in) { _calc = calc_in; }
+ void set_distance(double value) { _distance = value; }
+ const std::optional<double>& get_distance() const { return _distance; }
+ // This is used during unpacking, and also signals to the RawScoreCalculator that the entire document was a match.
+ std::optional<double> get_raw_score() const;
};
}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
index 226cb92c894..84344831cbc 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
@@ -200,15 +200,17 @@ QueryNode::build_nearest_neighbor_query_node(const QueryNodeResultFactory& facto
{
vespalib::stringref query_tensor_name = query_rep.getTerm();
vespalib::stringref field_name = query_rep.getIndexName();
- int32_t id = query_rep.getUniqueId();
- search::query::Weight weight = query_rep.GetWeight();
+ int32_t unique_id = query_rep.getUniqueId();
+ auto weight = query_rep.GetWeight();
+ uint32_t target_hits = query_rep.getTargetHits();
double distance_threshold = query_rep.getDistanceThreshold();
return std::make_unique<NearestNeighborQueryNode>(factory.create(),
query_tensor_name,
field_name,
- id,
- weight,
- distance_threshold);
+ target_hits,
+ distance_threshold,
+ unique_id,
+ weight);
}
}
diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
index 90bd87979c7..a552a650704 100644
--- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
+++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
@@ -89,7 +89,7 @@ private:
pureTermView = view;
} else if (type == ParseItem::ITEM_WEAK_AND) {
vespalib::stringref view = queryStack.getIndexName();
- uint32_t targetNumHits = queryStack.getTargetNumHits();
+ uint32_t targetNumHits = queryStack.getTargetHits();
builder.addWeakAnd(arity, targetNumHits, view);
pureTermView = view;
} else if (type == ParseItem::ITEM_EQUIV) {
@@ -134,7 +134,7 @@ private:
vespalib::stringref view = queryStack.getIndexName();
int32_t id = queryStack.getUniqueId();
Weight weight = queryStack.GetWeight();
- uint32_t targetNumHits = queryStack.getTargetNumHits();
+ uint32_t targetNumHits = queryStack.getTargetHits();
double scoreThreshold = queryStack.getScoreThreshold();
double thresholdBoostFactor = queryStack.getThresholdBoostFactor();
auto & wand = builder.addWandTerm(arity, view, id, weight, targetNumHits, scoreThreshold, thresholdBoostFactor);
@@ -146,7 +146,7 @@ private:
} else if (type == ParseItem::ITEM_NEAREST_NEIGHBOR) {
vespalib::stringref query_tensor_name = queryStack.getTerm();
vespalib::stringref field_name = queryStack.getIndexName();
- uint32_t target_num_hits = queryStack.getTargetNumHits();
+ uint32_t target_num_hits = queryStack.getTargetHits();
int32_t id = queryStack.getUniqueId();
Weight weight = queryStack.GetWeight();
bool allow_approximate = queryStack.getAllowApproximate();