diff options
author | Tor Egge <Tor.Egge@online.no> | 2024-01-15 16:04:45 +0100 |
---|---|---|
committer | Tor Egge <Tor.Egge@online.no> | 2024-01-15 16:04:45 +0100 |
commit | bf4f3338864574021ce7260dd527310e84bfeb32 (patch) | |
tree | f2ebcad28e51e55e13a871c52a47e36b72648627 | |
parent | 008338e2cfc4b455a07b8a8ab129b8cbbf9d0f2f (diff) |
Add WeightedSetTerm for streaming search.
7 files changed, 127 insertions, 4 deletions
diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index 7c4b7555158..52f01af3dff 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -7,6 +7,7 @@ #include <vespa/searchlib/query/streaming/query.h> #include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> #include <vespa/searchlib/query/streaming/wand_term.h> +#include <vespa/searchlib/query/streaming/weighted_set_term.h> #include <vespa/searchlib/query/tree/querybuilder.h> #include <vespa/searchlib/query/tree/simplequery.h> #include <vespa/searchlib/query/tree/stackdumpcreator.h> @@ -1020,6 +1021,42 @@ TEST(StreamingQueryTest, wand_term) check_wand_term(exp_wand_score_field_11 + 1, "hidden score below limit"); } +TEST(StreamingQueryTest, weighted_set_term) +{ + search::streaming::WeightedSetTerm term({}, "index", 2); + term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "7", "", QueryTermSimple::Type::WORD)); + term.get_terms().back()->setWeight(Weight(4)); + term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "9", "", QueryTermSimple::Type::WORD)); + term.get_terms().back()->setWeight(Weight(13)); + EXPECT_EQ(2, term.get_terms().size()); + SimpleTermData td; + td.addField(10); + td.addField(11); + td.addField(12); + td.lookupField(10)->setHandle(0); + td.lookupField(12)->setHandle(1); + EXPECT_FALSE(term.evaluate()); + auto& q0 = *term.get_terms()[0]; + q0.add(0, 11, 0, 10); + q0.add(0, 12, 0, 10); + auto& q1 = *term.get_terms()[1]; + q1.add(0, 11, 0, 10); + q1.add(0, 12, 0, 10); + EXPECT_TRUE(term.evaluate()); + MatchData md(MatchData::params().numTermFields(2)); + term.unpack_match_data(23, td, md); + auto tmd0 = md.resolveTermField(0); + EXPECT_NE(23, tmd0->getDocId()); + auto tmd1 = md.resolveTermField(1); + EXPECT_EQ(23, tmd1->getDocId()); + using Weights = std::vector<int32_t>; + Weights weights; + for (auto& pos : *tmd1) { + weights.emplace_back(pos.getElementWeight()); + } + EXPECT_EQ((Weights{13, 4}), weights); +} + TEST(StreamingQueryTest, control_the_size_of_query_terms) { EXPECT_EQ(112u, sizeof(QueryTermSimple)); diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt index 9b53407aff5..6b9be2e3269 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt @@ -10,5 +10,6 @@ vespa_add_library(searchlib_query_streaming OBJECT querynoderesultbase.cpp queryterm.cpp wand_term.cpp + weighted_set_term.cpp DEPENDS ) diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.h b/searchlib/src/vespa/searchlib/query/streaming/query.h index 8befa2fe7fa..84c693b86d0 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.h +++ b/searchlib/src/vespa/searchlib/query/streaming/query.h @@ -103,8 +103,7 @@ public: EquivQueryNode() noexcept : OrQueryNode("EQUIV") { } bool evaluate() const override; bool isFlattenable(ParseItem::ItemType type) const override { - return (type == ParseItem::ITEM_EQUIV) || - (type == ParseItem::ITEM_WEIGHTED_SET); + return (type == ParseItem::ITEM_EQUIV); } }; diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index c24f41d16cf..1ce80660d46 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -6,6 +6,7 @@ #include <vespa/searchlib/query/streaming/dot_product_term.h> #include <vespa/searchlib/query/streaming/in_term.h> #include <vespa/searchlib/query/streaming/wand_term.h> +#include <vespa/searchlib/query/streaming/weighted_set_term.h> #include <vespa/searchlib/query/tree/term_vector.h> #include <charconv> #include <vespa/log/log.h> @@ -40,7 +41,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_OR: case ParseItem::ITEM_WEAK_AND: case ParseItem::ITEM_EQUIV: - case ParseItem::ITEM_WEIGHTED_SET: case ParseItem::ITEM_NOT: case ParseItem::ITEM_PHRASE: case ParseItem::ITEM_SAME_ELEMENT: @@ -55,7 +55,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor nqn->distance(queryRep.getNearDistance()); } if ((type == ParseItem::ITEM_WEAK_AND) || - (type == ParseItem::ITEM_WEIGHTED_SET) || (type == ParseItem::ITEM_SAME_ELEMENT)) { qn->setIndex(queryRep.getIndexName()); @@ -192,6 +191,9 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_WAND: qn = build_wand_term(factory, queryRep); break; + case ParseItem::ITEM_WEIGHTED_SET: + qn = build_weighted_set_term(factory, queryRep); + break; default: skip_unknown(queryRep); break; @@ -270,6 +272,16 @@ QueryNode::build_wand_term(const QueryNodeResultFactory& factory, SimpleQuerySta return wand; } +std::unique_ptr<QueryNode> +QueryNode::build_weighted_set_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep) +{ + auto ws = std::make_unique<WeightedSetTerm>(factory.create(), queryRep.getIndexName(), queryRep.getArity()); + ws->setWeight(queryRep.GetWeight()); + ws->setUniqueId(queryRep.getUniqueId()); + populate_multi_term(factory.normalizing_mode(ws->index()), *ws, queryRep); + return ws; +} + void QueryNode::skip_unknown(SimpleQueryStackDumpIterator& queryRep) { diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.h b/searchlib/src/vespa/searchlib/query/streaming/querynode.h index a0561b2e52e..454932c0a68 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.h +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.h @@ -32,6 +32,7 @@ class QueryNode static void populate_multi_term(Normalizing string_normalize_mode, MultiTerm& mt, SimpleQueryStackDumpIterator& queryRep); static std::unique_ptr<QueryNode> build_dot_product_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); static std::unique_ptr<QueryNode> build_wand_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); + static std::unique_ptr<QueryNode> build_weighted_set_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); static void skip_unknown(SimpleQueryStackDumpIterator& queryRep); public: using UP = std::unique_ptr<QueryNode>; diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp new file mode 100644 index 00000000000..90d0be5d43c --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp @@ -0,0 +1,53 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "weighted_set_term.h" +#include <vespa/searchlib/fef/itermdata.h> +#include <vespa/searchlib/fef/matchdata.h> +#include <vespa/vespalib/stllike/hash_map.hpp> + +using search::fef::ITermData; +using search::fef::MatchData; + +namespace search::streaming { + +WeightedSetTerm::WeightedSetTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string & index, uint32_t num_terms) + : MultiTerm(std::move(result_base), index, num_terms) +{ +} + +WeightedSetTerm::~WeightedSetTerm() = default; + +void +WeightedSetTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data) +{ + vespalib::hash_map<uint32_t,std::vector<double>> scores; + HitList hl_store; + for (const auto& term : _terms) { + auto& hl = term->evaluateHits(hl_store); + for (auto& hit : hl) { + scores[hit.context()].emplace_back(term->weight().percent()); + } + } + auto num_fields = td.numFields(); + for (uint32_t field_idx = 0; field_idx < num_fields; ++field_idx) { + auto& tfd = td.field(field_idx); + auto field_id = tfd.getFieldId(); + if (scores.contains(field_id)) { + auto handle = tfd.getHandle(); + if (handle != fef::IllegalHandle) { + auto &field_scores = scores[field_id]; + std::sort(field_scores.begin(), field_scores.end(), std::greater()); + auto tmd = match_data.resolveTermField(tfd.getHandle()); + tmd->setFieldId(field_id); + tmd->reset(docid); + for (auto& field_score : field_scores) { + fef::TermFieldMatchDataPosition pos; + pos.setElementWeight(field_score); + tmd->appendPosition(pos); + } + } + } + } +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h new file mode 100644 index 00000000000..4473e0fa45b --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h @@ -0,0 +1,20 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "multi_term.h" + +namespace search::streaming { + +/* + * A weighted set query term for streaming search. + */ +class WeightedSetTerm : public MultiTerm { + double _score_threshold; +public: + WeightedSetTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string& index, uint32_t num_terms); + ~WeightedSetTerm() override; + void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) override; +}; + +} |