diff options
10 files changed, 178 insertions, 16 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java index 29f2d9aff9a..7ae02c18e7a 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java +++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java @@ -595,7 +595,7 @@ public class YqlParser implements Parser { } WandItem out = new WandItem(getIndex(args.get(0)), targetNumHits); Double scoreThreshold = getAnnotation(ast, SCORE_THRESHOLD, Double.class, null, - "min score for hit inclusion"); + "score must be above this threshold for hit inclusion"); if (scoreThreshold != null) { out.setScoreThreshold(scoreThreshold); } diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index c4ef2028123..7c4b7555158 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -6,6 +6,7 @@ #include <vespa/searchlib/query/streaming/in_term.h> #include <vespa/searchlib/query/streaming/query.h> #include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> +#include <vespa/searchlib/query/streaming/wand_term.h> #include <vespa/searchlib/query/tree/querybuilder.h> #include <vespa/searchlib/query/tree/simplequery.h> #include <vespa/searchlib/query/tree/stackdumpcreator.h> @@ -957,6 +958,68 @@ TEST(StreamingQueryTest, dot_product_term) EXPECT_EQ(-17 * 27 + 9 * 2, tmd1->getRawScore()); } +namespace { + +constexpr double exp_wand_score_field_12 = 13 * 27 + 4 * 2; +constexpr double exp_wand_score_field_11 = 17 * 27 + 9 * 2; + +void +check_wand_term(double limit, const vespalib::string& label) +{ + SCOPED_TRACE(label); + search::streaming::WandTerm term({}, "index", 2); + term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "7", "", QueryTermSimple::Type::WORD)); + term.get_terms().back()->setWeight(Weight(27)); + term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "9", "", QueryTermSimple::Type::WORD)); + term.get_terms().back()->setWeight(Weight(2)); + EXPECT_EQ(2, term.get_terms().size()); + term.set_score_threshold(limit); + SimpleTermData td; + /* + * Search in fields 10, 11 and 12 (cf. fieldset in schema). + * Fields 11 and 12 have content for doc containing the keys. + * Fields 10 and 12 have valid handles and can be used for ranking. + * Field 11 does not have a valid handle, thus no associated match data. + */ + td.addField(10); + td.addField(11); + td.addField(12); + td.lookupField(10)->setHandle(0); + td.lookupField(12)->setHandle(1); + EXPECT_FALSE(term.evaluate()); + auto& q0 = *term.get_terms()[0]; + q0.add(0, 11, 0, 17); + q0.add(0, 12, 0, 13); + auto& q1 = *term.get_terms()[1]; + q1.add(0, 11, 0, 9); + q1.add(0, 12, 0, 4); + EXPECT_EQ(limit < exp_wand_score_field_11, term.evaluate()); + MatchData md(MatchData::params().numTermFields(2)); + term.unpack_match_data(23, td, md); + auto tmd0 = md.resolveTermField(0); + EXPECT_NE(23, tmd0->getDocId()); + auto tmd1 = md.resolveTermField(1); + if (limit < exp_wand_score_field_12) { + EXPECT_EQ(23, tmd1->getDocId()); + EXPECT_EQ(exp_wand_score_field_12, tmd1->getRawScore()); + } else { + EXPECT_NE(23, tmd1->getDocId()); + } +} + +} + +TEST(StreamingQueryTest, wand_term) +{ + check_wand_term(0.0, "no limit"); + check_wand_term(exp_wand_score_field_12 - 1, "score above limit"); + check_wand_term(exp_wand_score_field_12, "score at limit"); + check_wand_term(exp_wand_score_field_12 + 1, "score below limit"); + check_wand_term(exp_wand_score_field_11 - 1, "hidden score above limit"); + check_wand_term(exp_wand_score_field_11, "hidden score at limit"); + check_wand_term(exp_wand_score_field_11 + 1, "hidden score below limit"); +} + TEST(StreamingQueryTest, control_the_size_of_query_terms) { EXPECT_EQ(112u, sizeof(QueryTermSimple)); diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt index 0813292a9da..9b53407aff5 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt @@ -9,5 +9,6 @@ vespa_add_library(searchlib_query_streaming OBJECT querynode.cpp querynoderesultbase.cpp queryterm.cpp + wand_term.cpp DEPENDS ) diff --git a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp index 9bb6d8c3342..1871bda564d 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp @@ -18,29 +18,44 @@ DotProductTerm::DotProductTerm(std::unique_ptr<QueryNodeResultBase> result_base, DotProductTerm::~DotProductTerm() = default; void -DotProductTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data) +DotProductTerm::build_scores(Scores& scores) const { - vespalib::hash_map<uint32_t,double> scores; HitList hl_store; for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - scores[hit.context()] += term->weight().percent() * hit.weight(); + scores[hit.context()] += ((int64_t)term->weight().percent()) * hit.weight(); } } +} + +void +DotProductTerm::unpack_scores(Scores& scores, std::optional<double> score_threshold, uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) +{ auto num_fields = td.numFields(); for (uint32_t field_idx = 0; field_idx < num_fields; ++field_idx) { auto& tfd = td.field(field_idx); auto field_id = tfd.getFieldId(); if (scores.contains(field_id)) { - auto handle = tfd.getHandle(); - if (handle != fef::IllegalHandle) { - auto tmd = match_data.resolveTermField(tfd.getHandle()); - tmd->setFieldId(field_id); - tmd->setRawScore(docid, scores[field_id]); + auto score = scores[field_id]; + if (!score_threshold.has_value() || score_threshold.value() < score) { + auto handle = tfd.getHandle(); + if (handle != fef::IllegalHandle) { + auto tmd = match_data.resolveTermField(tfd.getHandle()); + tmd->setFieldId(field_id); + tmd->setRawScore(docid, score); + } } } } } +void +DotProductTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data) +{ + Scores scores; + build_scores(scores); + unpack_scores(scores, std::nullopt, docid, td, match_data); +} + } diff --git a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h index 77cac693781..3702bd4721c 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h +++ b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.h @@ -3,6 +3,8 @@ #pragma once #include "multi_term.h" +#include <vespa/vespalib/stllike/hash_map.h> +#include <optional> namespace search::streaming { @@ -10,6 +12,10 @@ namespace search::streaming { * A dot product query term for streaming search. */ class DotProductTerm : public MultiTerm { +protected: + using Scores = vespalib::hash_map<uint32_t,double>; + void build_scores(Scores& scores) const; + void unpack_scores(Scores& scores, std::optional<double> score_threshold, uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data); public: DotProductTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string& index, uint32_t num_terms); ~DotProductTerm() override; diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.h b/searchlib/src/vespa/searchlib/query/streaming/query.h index 3904f743d26..8befa2fe7fa 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.h +++ b/searchlib/src/vespa/searchlib/query/streaming/query.h @@ -90,8 +90,6 @@ public: bool evaluate() const override; bool isFlattenable(ParseItem::ItemType type) const override { return (type == ParseItem::ITEM_OR) || - (type == ParseItem::ITEM_DOT_PRODUCT) || - (type == ParseItem::ITEM_WAND) || (type == ParseItem::ITEM_WEAK_AND); } }; diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index 1e43c32a263..c24f41d16cf 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -5,6 +5,7 @@ #include <vespa/searchlib/parsequery/stackdumpiterator.h> #include <vespa/searchlib/query/streaming/dot_product_term.h> #include <vespa/searchlib/query/streaming/in_term.h> +#include <vespa/searchlib/query/streaming/wand_term.h> #include <vespa/searchlib/query/tree/term_vector.h> #include <charconv> #include <vespa/log/log.h> @@ -40,7 +41,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_WEAK_AND: case ParseItem::ITEM_EQUIV: case ParseItem::ITEM_WEIGHTED_SET: - case ParseItem::ITEM_WAND: case ParseItem::ITEM_NOT: case ParseItem::ITEM_PHRASE: case ParseItem::ITEM_SAME_ELEMENT: @@ -56,9 +56,7 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor } if ((type == ParseItem::ITEM_WEAK_AND) || (type == ParseItem::ITEM_WEIGHTED_SET) || - (type == ParseItem::ITEM_DOT_PRODUCT) || - (type == ParseItem::ITEM_SAME_ELEMENT) || - (type == ParseItem::ITEM_WAND)) + (type == ParseItem::ITEM_SAME_ELEMENT)) { qn->setIndex(queryRep.getIndexName()); } @@ -191,6 +189,9 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_DOT_PRODUCT: qn = build_dot_product_term(factory, queryRep); break; + case ParseItem::ITEM_WAND: + qn = build_wand_term(factory, queryRep); + break; default: skip_unknown(queryRep); break; @@ -251,13 +252,24 @@ QueryNode::populate_multi_term(Normalizing string_normalize_mode, MultiTerm& mt, std::unique_ptr<QueryNode> QueryNode::build_dot_product_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep) { - auto dp =std::make_unique<DotProductTerm>(factory.create(), queryRep.getIndexName(), queryRep.getArity()); + auto dp = std::make_unique<DotProductTerm>(factory.create(), queryRep.getIndexName(), queryRep.getArity()); dp->setWeight(queryRep.GetWeight()); dp->setUniqueId(queryRep.getUniqueId()); populate_multi_term(factory.normalizing_mode(dp->index()), *dp, queryRep); return dp; } +std::unique_ptr<QueryNode> +QueryNode::build_wand_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep) +{ + auto wand = std::make_unique<WandTerm>(factory.create(), queryRep.getIndexName(), queryRep.getArity()); + wand->setWeight(queryRep.GetWeight()); + wand->setUniqueId(queryRep.getUniqueId()); + wand->set_score_threshold(queryRep.getScoreThreshold()); + populate_multi_term(factory.normalizing_mode(wand->index()), *wand, queryRep); + return wand; +} + void QueryNode::skip_unknown(SimpleQueryStackDumpIterator& queryRep) { diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.h b/searchlib/src/vespa/searchlib/query/streaming/querynode.h index 576d614e58b..a0561b2e52e 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.h +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.h @@ -31,6 +31,7 @@ class QueryNode static std::unique_ptr<QueryNode> build_nearest_neighbor_query_node(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); static void populate_multi_term(Normalizing string_normalize_mode, MultiTerm& mt, SimpleQueryStackDumpIterator& queryRep); static std::unique_ptr<QueryNode> build_dot_product_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); + static std::unique_ptr<QueryNode> build_wand_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); static void skip_unknown(SimpleQueryStackDumpIterator& queryRep); public: using UP = std::unique_ptr<QueryNode>; diff --git a/searchlib/src/vespa/searchlib/query/streaming/wand_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/wand_term.cpp new file mode 100644 index 00000000000..a561adf5b42 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/wand_term.cpp @@ -0,0 +1,44 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "wand_term.h" +#include <vespa/searchlib/fef/itermdata.h> +#include <vespa/searchlib/fef/matchdata.h> + +using search::fef::ITermData; +using search::fef::MatchData; + +namespace search::streaming { + +WandTerm::WandTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string & index, uint32_t num_terms) + : DotProductTerm(std::move(result_base), index, num_terms), + _score_threshold(0.0) +{ +} + +WandTerm::~WandTerm() = default; + +bool +WandTerm::evaluate() const +{ + if (_score_threshold <= 0.0) { + return DotProductTerm::evaluate(); + } + Scores scores; + build_scores(scores); + for (auto &field_and_score : scores) { + if (field_and_score.second > _score_threshold) { + return true; + } + } + return false; +} + +void +WandTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data) +{ + Scores scores; + build_scores(scores); + unpack_scores(scores, _score_threshold, docid, td, match_data); +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/wand_term.h b/searchlib/src/vespa/searchlib/query/streaming/wand_term.h new file mode 100644 index 00000000000..1b342834216 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/wand_term.h @@ -0,0 +1,22 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "dot_product_term.h" + +namespace search::streaming { + +/* + * A wand query term for streaming search. + */ +class WandTerm : public DotProductTerm { + double _score_threshold; +public: + WandTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string& index, uint32_t num_terms); + ~WandTerm() override; + void set_score_threshold(double value) { _score_threshold = value; } + bool evaluate() const override; + void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) override; +}; + +} |