From f7a4025bf64285ea5cbd3318e2b593f346a49050 Mon Sep 17 00:00:00 2001 From: Tor Egge Date: Fri, 12 Jan 2024 13:54:48 +0100 Subject: Calculate raw score for streaming search wand. --- searchlib/src/tests/query/streaming_query_test.cpp | 57 ++++++++++++++++++++++ .../vespa/searchlib/query/streaming/CMakeLists.txt | 1 + .../searchlib/query/streaming/dot_product_term.cpp | 31 +++++++++--- .../searchlib/query/streaming/dot_product_term.h | 6 +++ .../src/vespa/searchlib/query/streaming/query.h | 2 - .../vespa/searchlib/query/streaming/querynode.cpp | 22 +++++++-- .../vespa/searchlib/query/streaming/querynode.h | 1 + .../vespa/searchlib/query/streaming/wand_term.cpp | 44 +++++++++++++++++ .../vespa/searchlib/query/streaming/wand_term.h | 22 +++++++++ 9 files changed, 171 insertions(+), 15 deletions(-) create mode 100644 searchlib/src/vespa/searchlib/query/streaming/wand_term.cpp create mode 100644 searchlib/src/vespa/searchlib/query/streaming/wand_term.h (limited to 'searchlib') diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index c4ef2028123..2a71fce85c3 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -957,6 +958,62 @@ TEST(StreamingQueryTest, dot_product_term) EXPECT_EQ(-17 * 27 + 9 * 2, tmd1->getRawScore()); } +namespace { + +constexpr double exp_wand_score = 13 * 27 + 4 * 2; +constexpr double exp_wand_hidden_score = 17 * 27 + 9 * 2; + +void +check_wand_term(double limit, const vespalib::string& label) +{ + SCOPED_TRACE(label); + search::streaming::WandTerm term({}, "index", 2); + term.add_term(std::make_unique(std::unique_ptr(), "7", "", QueryTermSimple::Type::WORD)); + term.get_terms().back()->setWeight(Weight(27)); + term.add_term(std::make_unique(std::unique_ptr(), "9", "", QueryTermSimple::Type::WORD)); + term.get_terms().back()->setWeight(Weight(2)); + EXPECT_EQ(2, term.get_terms().size()); + term.set_score_threshold(limit); + SimpleTermData td; + td.addField(10); + td.addField(11); + td.addField(12); + td.lookupField(10)->setHandle(0); + td.lookupField(12)->setHandle(1); + EXPECT_FALSE(term.evaluate()); + auto& q0 = *term.get_terms()[0]; + q0.add(0, 11, 0, 17); + q0.add(0, 12, 0, 13); + auto& q1 = *term.get_terms()[1]; + q1.add(0, 11, 0, 9); + q1.add(0, 12, 0, 4); + EXPECT_EQ(limit < exp_wand_hidden_score, term.evaluate()); + MatchData md(MatchData::params().numTermFields(2)); + term.unpack_match_data(23, td, md); + auto tmd0 = md.resolveTermField(0); + EXPECT_NE(23, tmd0->getDocId()); + auto tmd1 = md.resolveTermField(1); + if (limit < exp_wand_score) { + EXPECT_EQ(23, tmd1->getDocId()); + EXPECT_EQ(exp_wand_score, tmd1->getRawScore()); + } else { + EXPECT_NE(23, tmd1->getDocId()); + } +} + +} + +TEST(StreamingQueryTest, wand_term) +{ + check_wand_term(0.0, "no limit"); + check_wand_term(exp_wand_score - 1, "score above limit"); + check_wand_term(exp_wand_score, "score at limit"); + check_wand_term(exp_wand_score + 1, "score below limit"); + check_wand_term(exp_wand_hidden_score - 1, "hidden score above limit"); + check_wand_term(exp_wand_hidden_score, "hidden score at limit"); + check_wand_term(exp_wand_hidden_score + 1, "hidden score below limit"); +} + TEST(StreamingQueryTest, control_the_size_of_query_terms) { EXPECT_EQ(112u, sizeof(QueryTermSimple)); diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt index 0813292a9da..9b53407aff5 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt @@ -9,5 +9,6 @@ vespa_add_library(searchlib_query_streaming OBJECT querynode.cpp querynoderesultbase.cpp queryterm.cpp + wand_term.cpp DEPENDS ) diff --git a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp index 9bb6d8c3342..1871bda564d 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp @@ -18,29 +18,44 @@ DotProductTerm::DotProductTerm(std::unique_ptr result_base, DotProductTerm::~DotProductTerm() = default; void -DotProductTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data) +DotProductTerm::build_scores(Scores& scores) const { - vespalib::hash_map scores; HitList hl_store; for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - scores[hit.context()] += term->weight().percent() * hit.weight(); + scores[hit.context()] += ((int64_t)term->weight().percent()) * hit.weight(); } } +} + +void +DotProductTerm::unpack_scores(Scores& scores, std::optional score_threshold, uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) +{ auto num_fields = td.numFields(); for (uint32_t field_idx = 0; field_idx < num_fields; ++field_idx) { auto& tfd = td.field(field_idx); auto field_id = tfd.getFieldId(); if (scores.contains(field_id)) { - auto handle = tfd.getHandle(); - if (handle != fef::IllegalHandle) { - auto tmd = match_data.resolveTermField(tfd.getHandle()); - tmd->setFieldId(field_id); - tmd->setRawScore(docid, scores[field_id]); + auto score = scores[field_id]; + if (!score_threshold.has_value() || score_threshold.value() < score) { + auto handle = tfd.getHandle(); + if (handle != fef::IllegalHandle) { + auto tmd = match_data.resolveTermField(tfd.getHandle()); + tmd->setFieldId(field_id); + tmd->setRawScore(docid, score); + } } } } } +void +DotProductTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data) +{ + Scores scores; + build_scores(scores); + unpack_scores(scores, std::nullopt, docid, td, match_data); +} + } diff --git a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h index 77cac693781..3702bd4721c 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h +++ b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h @@ -3,6 +3,8 @@ #pragma once #include "multi_term.h" +#include +#include namespace search::streaming { @@ -10,6 +12,10 @@ namespace search::streaming { * A dot product query term for streaming search. */ class DotProductTerm : public MultiTerm { +protected: + using Scores = vespalib::hash_map; + void build_scores(Scores& scores) const; + void unpack_scores(Scores& scores, std::optional score_threshold, uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data); public: DotProductTerm(std::unique_ptr result_base, const string& index, uint32_t num_terms); ~DotProductTerm() override; diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.h b/searchlib/src/vespa/searchlib/query/streaming/query.h index 3904f743d26..8befa2fe7fa 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.h +++ b/searchlib/src/vespa/searchlib/query/streaming/query.h @@ -90,8 +90,6 @@ public: bool evaluate() const override; bool isFlattenable(ParseItem::ItemType type) const override { return (type == ParseItem::ITEM_OR) || - (type == ParseItem::ITEM_DOT_PRODUCT) || - (type == ParseItem::ITEM_WAND) || (type == ParseItem::ITEM_WEAK_AND); } }; diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index 1e43c32a263..c24f41d16cf 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -40,7 +41,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_WEAK_AND: case ParseItem::ITEM_EQUIV: case ParseItem::ITEM_WEIGHTED_SET: - case ParseItem::ITEM_WAND: case ParseItem::ITEM_NOT: case ParseItem::ITEM_PHRASE: case ParseItem::ITEM_SAME_ELEMENT: @@ -56,9 +56,7 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor } if ((type == ParseItem::ITEM_WEAK_AND) || (type == ParseItem::ITEM_WEIGHTED_SET) || - (type == ParseItem::ITEM_DOT_PRODUCT) || - (type == ParseItem::ITEM_SAME_ELEMENT) || - (type == ParseItem::ITEM_WAND)) + (type == ParseItem::ITEM_SAME_ELEMENT)) { qn->setIndex(queryRep.getIndexName()); } @@ -191,6 +189,9 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_DOT_PRODUCT: qn = build_dot_product_term(factory, queryRep); break; + case ParseItem::ITEM_WAND: + qn = build_wand_term(factory, queryRep); + break; default: skip_unknown(queryRep); break; @@ -251,13 +252,24 @@ QueryNode::populate_multi_term(Normalizing string_normalize_mode, MultiTerm& mt, std::unique_ptr QueryNode::build_dot_product_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep) { - auto dp =std::make_unique(factory.create(), queryRep.getIndexName(), queryRep.getArity()); + auto dp = std::make_unique(factory.create(), queryRep.getIndexName(), queryRep.getArity()); dp->setWeight(queryRep.GetWeight()); dp->setUniqueId(queryRep.getUniqueId()); populate_multi_term(factory.normalizing_mode(dp->index()), *dp, queryRep); return dp; } +std::unique_ptr +QueryNode::build_wand_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep) +{ + auto wand = std::make_unique(factory.create(), queryRep.getIndexName(), queryRep.getArity()); + wand->setWeight(queryRep.GetWeight()); + wand->setUniqueId(queryRep.getUniqueId()); + wand->set_score_threshold(queryRep.getScoreThreshold()); + populate_multi_term(factory.normalizing_mode(wand->index()), *wand, queryRep); + return wand; +} + void QueryNode::skip_unknown(SimpleQueryStackDumpIterator& queryRep) { diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.h b/searchlib/src/vespa/searchlib/query/streaming/querynode.h index 576d614e58b..a0561b2e52e 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.h +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.h @@ -31,6 +31,7 @@ class QueryNode static std::unique_ptr build_nearest_neighbor_query_node(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); static void populate_multi_term(Normalizing string_normalize_mode, MultiTerm& mt, SimpleQueryStackDumpIterator& queryRep); static std::unique_ptr build_dot_product_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); + static std::unique_ptr build_wand_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); static void skip_unknown(SimpleQueryStackDumpIterator& queryRep); public: using UP = std::unique_ptr; diff --git a/searchlib/src/vespa/searchlib/query/streaming/wand_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/wand_term.cpp new file mode 100644 index 00000000000..a561adf5b42 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/wand_term.cpp @@ -0,0 +1,44 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "wand_term.h" +#include +#include + +using search::fef::ITermData; +using search::fef::MatchData; + +namespace search::streaming { + +WandTerm::WandTerm(std::unique_ptr result_base, const string & index, uint32_t num_terms) + : DotProductTerm(std::move(result_base), index, num_terms), + _score_threshold(0.0) +{ +} + +WandTerm::~WandTerm() = default; + +bool +WandTerm::evaluate() const +{ + if (_score_threshold <= 0.0) { + return DotProductTerm::evaluate(); + } + Scores scores; + build_scores(scores); + for (auto &field_and_score : scores) { + if (field_and_score.second > _score_threshold) { + return true; + } + } + return false; +} + +void +WandTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data) +{ + Scores scores; + build_scores(scores); + unpack_scores(scores, _score_threshold, docid, td, match_data); +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/wand_term.h b/searchlib/src/vespa/searchlib/query/streaming/wand_term.h new file mode 100644 index 00000000000..1b342834216 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/wand_term.h @@ -0,0 +1,22 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "dot_product_term.h" + +namespace search::streaming { + +/* + * A wand query term for streaming search. + */ +class WandTerm : public DotProductTerm { + double _score_threshold; +public: + WandTerm(std::unique_ptr result_base, const string& index, uint32_t num_terms); + ~WandTerm() override; + void set_score_threshold(double value) { _score_threshold = value; } + bool evaluate() const override; + void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) override; +}; + +} -- cgit v1.2.3 From 97c388081bdf51a6d086c1e51e3919d8eb427117 Mon Sep 17 00:00:00 2001 From: Tor Egge Date: Fri, 12 Jan 2024 15:38:29 +0100 Subject: Rename constants and add comment for streaming search wand term unit test. --- searchlib/src/tests/query/streaming_query_test.cpp | 28 +++++++++++++--------- 1 file changed, 17 insertions(+), 11 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index 2a71fce85c3..7c4b7555158 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -960,8 +960,8 @@ TEST(StreamingQueryTest, dot_product_term) namespace { -constexpr double exp_wand_score = 13 * 27 + 4 * 2; -constexpr double exp_wand_hidden_score = 17 * 27 + 9 * 2; +constexpr double exp_wand_score_field_12 = 13 * 27 + 4 * 2; +constexpr double exp_wand_score_field_11 = 17 * 27 + 9 * 2; void check_wand_term(double limit, const vespalib::string& label) @@ -975,6 +975,12 @@ check_wand_term(double limit, const vespalib::string& label) EXPECT_EQ(2, term.get_terms().size()); term.set_score_threshold(limit); SimpleTermData td; + /* + * Search in fields 10, 11 and 12 (cf. fieldset in schema). + * Fields 11 and 12 have content for doc containing the keys. + * Fields 10 and 12 have valid handles and can be used for ranking. + * Field 11 does not have a valid handle, thus no associated match data. + */ td.addField(10); td.addField(11); td.addField(12); @@ -987,15 +993,15 @@ check_wand_term(double limit, const vespalib::string& label) auto& q1 = *term.get_terms()[1]; q1.add(0, 11, 0, 9); q1.add(0, 12, 0, 4); - EXPECT_EQ(limit < exp_wand_hidden_score, term.evaluate()); + EXPECT_EQ(limit < exp_wand_score_field_11, term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); auto tmd0 = md.resolveTermField(0); EXPECT_NE(23, tmd0->getDocId()); auto tmd1 = md.resolveTermField(1); - if (limit < exp_wand_score) { + if (limit < exp_wand_score_field_12) { EXPECT_EQ(23, tmd1->getDocId()); - EXPECT_EQ(exp_wand_score, tmd1->getRawScore()); + EXPECT_EQ(exp_wand_score_field_12, tmd1->getRawScore()); } else { EXPECT_NE(23, tmd1->getDocId()); } @@ -1006,12 +1012,12 @@ check_wand_term(double limit, const vespalib::string& label) TEST(StreamingQueryTest, wand_term) { check_wand_term(0.0, "no limit"); - check_wand_term(exp_wand_score - 1, "score above limit"); - check_wand_term(exp_wand_score, "score at limit"); - check_wand_term(exp_wand_score + 1, "score below limit"); - check_wand_term(exp_wand_hidden_score - 1, "hidden score above limit"); - check_wand_term(exp_wand_hidden_score, "hidden score at limit"); - check_wand_term(exp_wand_hidden_score + 1, "hidden score below limit"); + check_wand_term(exp_wand_score_field_12 - 1, "score above limit"); + check_wand_term(exp_wand_score_field_12, "score at limit"); + check_wand_term(exp_wand_score_field_12 + 1, "score below limit"); + check_wand_term(exp_wand_score_field_11 - 1, "hidden score above limit"); + check_wand_term(exp_wand_score_field_11, "hidden score at limit"); + check_wand_term(exp_wand_score_field_11 + 1, "hidden score below limit"); } TEST(StreamingQueryTest, control_the_size_of_query_terms) -- cgit v1.2.3