summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorTor Egge <tegge@vespa.ai>2024-01-12 16:20:03 +0100
committerGitHub <noreply@github.com>2024-01-12 16:20:03 +0100
commit0213cea0df0706a1f1a70c44e8c2b5906745a6ab (patch)
tree1c5014b12170db844acbcf7ecc86525db8cc0e62 /searchlib
parentf2ded8dd8ebfc2c567fffea98b5e750ab1ed0da1 (diff)
parent97c388081bdf51a6d086c1e51e3919d8eb427117 (diff)
Merge pull request #29879 from vespa-engine/toregge/dot-product-raw-score-for-streaming-search-wand
Calculate raw score for streaming search wand.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/query/streaming_query_test.cpp63
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp31
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h6
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/query.h2
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/querynode.cpp22
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/querynode.h1
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/wand_term.cpp44
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/wand_term.h22
9 files changed, 177 insertions, 15 deletions
diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp
index c4ef2028123..7c4b7555158 100644
--- a/searchlib/src/tests/query/streaming_query_test.cpp
+++ b/searchlib/src/tests/query/streaming_query_test.cpp
@@ -6,6 +6,7 @@
#include <vespa/searchlib/query/streaming/in_term.h>
#include <vespa/searchlib/query/streaming/query.h>
#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h>
+#include <vespa/searchlib/query/streaming/wand_term.h>
#include <vespa/searchlib/query/tree/querybuilder.h>
#include <vespa/searchlib/query/tree/simplequery.h>
#include <vespa/searchlib/query/tree/stackdumpcreator.h>
@@ -957,6 +958,68 @@ TEST(StreamingQueryTest, dot_product_term)
EXPECT_EQ(-17 * 27 + 9 * 2, tmd1->getRawScore());
}
+namespace {
+
+constexpr double exp_wand_score_field_12 = 13 * 27 + 4 * 2;
+constexpr double exp_wand_score_field_11 = 17 * 27 + 9 * 2;
+
+void
+check_wand_term(double limit, const vespalib::string& label)
+{
+ SCOPED_TRACE(label);
+ search::streaming::WandTerm term({}, "index", 2);
+ term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "7", "", QueryTermSimple::Type::WORD));
+ term.get_terms().back()->setWeight(Weight(27));
+ term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "9", "", QueryTermSimple::Type::WORD));
+ term.get_terms().back()->setWeight(Weight(2));
+ EXPECT_EQ(2, term.get_terms().size());
+ term.set_score_threshold(limit);
+ SimpleTermData td;
+ /*
+ * Search in fields 10, 11 and 12 (cf. fieldset in schema).
+ * Fields 11 and 12 have content for doc containing the keys.
+ * Fields 10 and 12 have valid handles and can be used for ranking.
+ * Field 11 does not have a valid handle, thus no associated match data.
+ */
+ td.addField(10);
+ td.addField(11);
+ td.addField(12);
+ td.lookupField(10)->setHandle(0);
+ td.lookupField(12)->setHandle(1);
+ EXPECT_FALSE(term.evaluate());
+ auto& q0 = *term.get_terms()[0];
+ q0.add(0, 11, 0, 17);
+ q0.add(0, 12, 0, 13);
+ auto& q1 = *term.get_terms()[1];
+ q1.add(0, 11, 0, 9);
+ q1.add(0, 12, 0, 4);
+ EXPECT_EQ(limit < exp_wand_score_field_11, term.evaluate());
+ MatchData md(MatchData::params().numTermFields(2));
+ term.unpack_match_data(23, td, md);
+ auto tmd0 = md.resolveTermField(0);
+ EXPECT_NE(23, tmd0->getDocId());
+ auto tmd1 = md.resolveTermField(1);
+ if (limit < exp_wand_score_field_12) {
+ EXPECT_EQ(23, tmd1->getDocId());
+ EXPECT_EQ(exp_wand_score_field_12, tmd1->getRawScore());
+ } else {
+ EXPECT_NE(23, tmd1->getDocId());
+ }
+}
+
+}
+
+TEST(StreamingQueryTest, wand_term)
+{
+ check_wand_term(0.0, "no limit");
+ check_wand_term(exp_wand_score_field_12 - 1, "score above limit");
+ check_wand_term(exp_wand_score_field_12, "score at limit");
+ check_wand_term(exp_wand_score_field_12 + 1, "score below limit");
+ check_wand_term(exp_wand_score_field_11 - 1, "hidden score above limit");
+ check_wand_term(exp_wand_score_field_11, "hidden score at limit");
+ check_wand_term(exp_wand_score_field_11 + 1, "hidden score below limit");
+}
+
TEST(StreamingQueryTest, control_the_size_of_query_terms)
{
EXPECT_EQ(112u, sizeof(QueryTermSimple));
diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
index 0813292a9da..9b53407aff5 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
@@ -9,5 +9,6 @@ vespa_add_library(searchlib_query_streaming OBJECT
querynode.cpp
querynoderesultbase.cpp
queryterm.cpp
+ wand_term.cpp
DEPENDS
)
diff --git a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp
index 9bb6d8c3342..1871bda564d 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp
@@ -18,29 +18,44 @@ DotProductTerm::DotProductTerm(std::unique_ptr<QueryNodeResultBase> result_base,
DotProductTerm::~DotProductTerm() = default;
void
-DotProductTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data)
+DotProductTerm::build_scores(Scores& scores) const
{
- vespalib::hash_map<uint32_t,double> scores;
HitList hl_store;
for (const auto& term : _terms) {
auto& hl = term->evaluateHits(hl_store);
for (auto& hit : hl) {
- scores[hit.context()] += term->weight().percent() * hit.weight();
+ scores[hit.context()] += ((int64_t)term->weight().percent()) * hit.weight();
}
}
+}
+
+void
+DotProductTerm::unpack_scores(Scores& scores, std::optional<double> score_threshold, uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data)
+{
auto num_fields = td.numFields();
for (uint32_t field_idx = 0; field_idx < num_fields; ++field_idx) {
auto& tfd = td.field(field_idx);
auto field_id = tfd.getFieldId();
if (scores.contains(field_id)) {
- auto handle = tfd.getHandle();
- if (handle != fef::IllegalHandle) {
- auto tmd = match_data.resolveTermField(tfd.getHandle());
- tmd->setFieldId(field_id);
- tmd->setRawScore(docid, scores[field_id]);
+ auto score = scores[field_id];
+ if (!score_threshold.has_value() || score_threshold.value() < score) {
+ auto handle = tfd.getHandle();
+ if (handle != fef::IllegalHandle) {
+ auto tmd = match_data.resolveTermField(tfd.getHandle());
+ tmd->setFieldId(field_id);
+ tmd->setRawScore(docid, score);
+ }
}
}
}
}
+void
+DotProductTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data)
+{
+ Scores scores;
+ build_scores(scores);
+ unpack_scores(scores, std::nullopt, docid, td, match_data);
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h
index 77cac693781..3702bd4721c 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h
@@ -3,6 +3,8 @@
#pragma once
#include "multi_term.h"
+#include <vespa/vespalib/stllike/hash_map.h>
+#include <optional>
namespace search::streaming {
@@ -10,6 +12,10 @@ namespace search::streaming {
* A dot product query term for streaming search.
*/
class DotProductTerm : public MultiTerm {
+protected:
+ using Scores = vespalib::hash_map<uint32_t,double>;
+ void build_scores(Scores& scores) const;
+ void unpack_scores(Scores& scores, std::optional<double> score_threshold, uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data);
public:
DotProductTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string& index, uint32_t num_terms);
~DotProductTerm() override;
diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.h b/searchlib/src/vespa/searchlib/query/streaming/query.h
index 3904f743d26..8befa2fe7fa 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/query.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/query.h
@@ -90,8 +90,6 @@ public:
bool evaluate() const override;
bool isFlattenable(ParseItem::ItemType type) const override {
return (type == ParseItem::ITEM_OR) ||
- (type == ParseItem::ITEM_DOT_PRODUCT) ||
- (type == ParseItem::ITEM_WAND) ||
(type == ParseItem::ITEM_WEAK_AND);
}
};
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
index 1e43c32a263..c24f41d16cf 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
@@ -5,6 +5,7 @@
#include <vespa/searchlib/parsequery/stackdumpiterator.h>
#include <vespa/searchlib/query/streaming/dot_product_term.h>
#include <vespa/searchlib/query/streaming/in_term.h>
+#include <vespa/searchlib/query/streaming/wand_term.h>
#include <vespa/searchlib/query/tree/term_vector.h>
#include <charconv>
#include <vespa/log/log.h>
@@ -40,7 +41,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
case ParseItem::ITEM_WEAK_AND:
case ParseItem::ITEM_EQUIV:
case ParseItem::ITEM_WEIGHTED_SET:
- case ParseItem::ITEM_WAND:
case ParseItem::ITEM_NOT:
case ParseItem::ITEM_PHRASE:
case ParseItem::ITEM_SAME_ELEMENT:
@@ -56,9 +56,7 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
}
if ((type == ParseItem::ITEM_WEAK_AND) ||
(type == ParseItem::ITEM_WEIGHTED_SET) ||
- (type == ParseItem::ITEM_DOT_PRODUCT) ||
- (type == ParseItem::ITEM_SAME_ELEMENT) ||
- (type == ParseItem::ITEM_WAND))
+ (type == ParseItem::ITEM_SAME_ELEMENT))
{
qn->setIndex(queryRep.getIndexName());
}
@@ -191,6 +189,9 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
case ParseItem::ITEM_DOT_PRODUCT:
qn = build_dot_product_term(factory, queryRep);
break;
+ case ParseItem::ITEM_WAND:
+ qn = build_wand_term(factory, queryRep);
+ break;
default:
skip_unknown(queryRep);
break;
@@ -251,13 +252,24 @@ QueryNode::populate_multi_term(Normalizing string_normalize_mode, MultiTerm& mt,
std::unique_ptr<QueryNode>
QueryNode::build_dot_product_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep)
{
- auto dp =std::make_unique<DotProductTerm>(factory.create(), queryRep.getIndexName(), queryRep.getArity());
+ auto dp = std::make_unique<DotProductTerm>(factory.create(), queryRep.getIndexName(), queryRep.getArity());
dp->setWeight(queryRep.GetWeight());
dp->setUniqueId(queryRep.getUniqueId());
populate_multi_term(factory.normalizing_mode(dp->index()), *dp, queryRep);
return dp;
}
+std::unique_ptr<QueryNode>
+QueryNode::build_wand_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep)
+{
+ auto wand = std::make_unique<WandTerm>(factory.create(), queryRep.getIndexName(), queryRep.getArity());
+ wand->setWeight(queryRep.GetWeight());
+ wand->setUniqueId(queryRep.getUniqueId());
+ wand->set_score_threshold(queryRep.getScoreThreshold());
+ populate_multi_term(factory.normalizing_mode(wand->index()), *wand, queryRep);
+ return wand;
+}
+
void
QueryNode::skip_unknown(SimpleQueryStackDumpIterator& queryRep)
{
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.h b/searchlib/src/vespa/searchlib/query/streaming/querynode.h
index 576d614e58b..a0561b2e52e 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynode.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.h
@@ -31,6 +31,7 @@ class QueryNode
static std::unique_ptr<QueryNode> build_nearest_neighbor_query_node(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep);
static void populate_multi_term(Normalizing string_normalize_mode, MultiTerm& mt, SimpleQueryStackDumpIterator& queryRep);
static std::unique_ptr<QueryNode> build_dot_product_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep);
+ static std::unique_ptr<QueryNode> build_wand_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep);
static void skip_unknown(SimpleQueryStackDumpIterator& queryRep);
public:
using UP = std::unique_ptr<QueryNode>;
diff --git a/searchlib/src/vespa/searchlib/query/streaming/wand_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/wand_term.cpp
new file mode 100644
index 00000000000..a561adf5b42
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/query/streaming/wand_term.cpp
@@ -0,0 +1,44 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "wand_term.h"
+#include <vespa/searchlib/fef/itermdata.h>
+#include <vespa/searchlib/fef/matchdata.h>
+
+using search::fef::ITermData;
+using search::fef::MatchData;
+
+namespace search::streaming {
+
+WandTerm::WandTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string & index, uint32_t num_terms)
+ : DotProductTerm(std::move(result_base), index, num_terms),
+ _score_threshold(0.0)
+{
+}
+
+WandTerm::~WandTerm() = default;
+
+bool
+WandTerm::evaluate() const
+{
+ if (_score_threshold <= 0.0) {
+ return DotProductTerm::evaluate();
+ }
+ Scores scores;
+ build_scores(scores);
+ for (auto &field_and_score : scores) {
+ if (field_and_score.second > _score_threshold) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void
+WandTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data)
+{
+ Scores scores;
+ build_scores(scores);
+ unpack_scores(scores, _score_threshold, docid, td, match_data);
+}
+
+}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/wand_term.h b/searchlib/src/vespa/searchlib/query/streaming/wand_term.h
new file mode 100644
index 00000000000..1b342834216
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/query/streaming/wand_term.h
@@ -0,0 +1,22 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "dot_product_term.h"
+
+namespace search::streaming {
+
+/*
+ * A wand query term for streaming search.
+ */
+class WandTerm : public DotProductTerm {
+ double _score_threshold;
+public:
+ WandTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string& index, uint32_t num_terms);
+ ~WandTerm() override;
+ void set_score_threshold(double value) { _score_threshold = value; }
+ bool evaluate() const override;
+ void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) override;
+};
+
+}