From bf4f3338864574021ce7260dd527310e84bfeb32 Mon Sep 17 00:00:00 2001 From: Tor Egge Date: Mon, 15 Jan 2024 16:04:45 +0100 Subject: Add WeightedSetTerm for streaming search. --- searchlib/src/tests/query/streaming_query_test.cpp | 37 +++++++++++++++ .../vespa/searchlib/query/streaming/CMakeLists.txt | 1 + .../src/vespa/searchlib/query/streaming/query.h | 3 +- .../vespa/searchlib/query/streaming/querynode.cpp | 16 ++++++- .../vespa/searchlib/query/streaming/querynode.h | 1 + .../query/streaming/weighted_set_term.cpp | 53 ++++++++++++++++++++++ .../searchlib/query/streaming/weighted_set_term.h | 20 ++++++++ 7 files changed, 127 insertions(+), 4 deletions(-) create mode 100644 searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp create mode 100644 searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h (limited to 'searchlib') diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index 7c4b7555158..52f01af3dff 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -1020,6 +1021,42 @@ TEST(StreamingQueryTest, wand_term) check_wand_term(exp_wand_score_field_11 + 1, "hidden score below limit"); } +TEST(StreamingQueryTest, weighted_set_term) +{ + search::streaming::WeightedSetTerm term({}, "index", 2); + term.add_term(std::make_unique(std::unique_ptr(), "7", "", QueryTermSimple::Type::WORD)); + term.get_terms().back()->setWeight(Weight(4)); + term.add_term(std::make_unique(std::unique_ptr(), "9", "", QueryTermSimple::Type::WORD)); + term.get_terms().back()->setWeight(Weight(13)); + EXPECT_EQ(2, term.get_terms().size()); + SimpleTermData td; + td.addField(10); + td.addField(11); + td.addField(12); + td.lookupField(10)->setHandle(0); + td.lookupField(12)->setHandle(1); + EXPECT_FALSE(term.evaluate()); + auto& q0 = *term.get_terms()[0]; + q0.add(0, 11, 0, 10); + q0.add(0, 12, 0, 10); + auto& q1 = *term.get_terms()[1]; + q1.add(0, 11, 0, 10); + q1.add(0, 12, 0, 10); + EXPECT_TRUE(term.evaluate()); + MatchData md(MatchData::params().numTermFields(2)); + term.unpack_match_data(23, td, md); + auto tmd0 = md.resolveTermField(0); + EXPECT_NE(23, tmd0->getDocId()); + auto tmd1 = md.resolveTermField(1); + EXPECT_EQ(23, tmd1->getDocId()); + using Weights = std::vector; + Weights weights; + for (auto& pos : *tmd1) { + weights.emplace_back(pos.getElementWeight()); + } + EXPECT_EQ((Weights{13, 4}), weights); +} + TEST(StreamingQueryTest, control_the_size_of_query_terms) { EXPECT_EQ(112u, sizeof(QueryTermSimple)); diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt index 9b53407aff5..6b9be2e3269 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt @@ -10,5 +10,6 @@ vespa_add_library(searchlib_query_streaming OBJECT querynoderesultbase.cpp queryterm.cpp wand_term.cpp + weighted_set_term.cpp DEPENDS ) diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.h b/searchlib/src/vespa/searchlib/query/streaming/query.h index 8befa2fe7fa..84c693b86d0 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.h +++ b/searchlib/src/vespa/searchlib/query/streaming/query.h @@ -103,8 +103,7 @@ public: EquivQueryNode() noexcept : OrQueryNode("EQUIV") { } bool evaluate() const override; bool isFlattenable(ParseItem::ItemType type) const override { - return (type == ParseItem::ITEM_EQUIV) || - (type == ParseItem::ITEM_WEIGHTED_SET); + return (type == ParseItem::ITEM_EQUIV); } }; diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index c24f41d16cf..1ce80660d46 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -40,7 +41,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_OR: case ParseItem::ITEM_WEAK_AND: case ParseItem::ITEM_EQUIV: - case ParseItem::ITEM_WEIGHTED_SET: case ParseItem::ITEM_NOT: case ParseItem::ITEM_PHRASE: case ParseItem::ITEM_SAME_ELEMENT: @@ -55,7 +55,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor nqn->distance(queryRep.getNearDistance()); } if ((type == ParseItem::ITEM_WEAK_AND) || - (type == ParseItem::ITEM_WEIGHTED_SET) || (type == ParseItem::ITEM_SAME_ELEMENT)) { qn->setIndex(queryRep.getIndexName()); @@ -192,6 +191,9 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_WAND: qn = build_wand_term(factory, queryRep); break; + case ParseItem::ITEM_WEIGHTED_SET: + qn = build_weighted_set_term(factory, queryRep); + break; default: skip_unknown(queryRep); break; @@ -270,6 +272,16 @@ QueryNode::build_wand_term(const QueryNodeResultFactory& factory, SimpleQuerySta return wand; } +std::unique_ptr +QueryNode::build_weighted_set_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep) +{ + auto ws = std::make_unique(factory.create(), queryRep.getIndexName(), queryRep.getArity()); + ws->setWeight(queryRep.GetWeight()); + ws->setUniqueId(queryRep.getUniqueId()); + populate_multi_term(factory.normalizing_mode(ws->index()), *ws, queryRep); + return ws; +} + void QueryNode::skip_unknown(SimpleQueryStackDumpIterator& queryRep) { diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.h b/searchlib/src/vespa/searchlib/query/streaming/querynode.h index a0561b2e52e..454932c0a68 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.h +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.h @@ -32,6 +32,7 @@ class QueryNode static void populate_multi_term(Normalizing string_normalize_mode, MultiTerm& mt, SimpleQueryStackDumpIterator& queryRep); static std::unique_ptr build_dot_product_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); static std::unique_ptr build_wand_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); + static std::unique_ptr build_weighted_set_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); static void skip_unknown(SimpleQueryStackDumpIterator& queryRep); public: using UP = std::unique_ptr; diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp new file mode 100644 index 00000000000..90d0be5d43c --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp @@ -0,0 +1,53 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "weighted_set_term.h" +#include +#include +#include + +using search::fef::ITermData; +using search::fef::MatchData; + +namespace search::streaming { + +WeightedSetTerm::WeightedSetTerm(std::unique_ptr result_base, const string & index, uint32_t num_terms) + : MultiTerm(std::move(result_base), index, num_terms) +{ +} + +WeightedSetTerm::~WeightedSetTerm() = default; + +void +WeightedSetTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data) +{ + vespalib::hash_map> scores; + HitList hl_store; + for (const auto& term : _terms) { + auto& hl = term->evaluateHits(hl_store); + for (auto& hit : hl) { + scores[hit.context()].emplace_back(term->weight().percent()); + } + } + auto num_fields = td.numFields(); + for (uint32_t field_idx = 0; field_idx < num_fields; ++field_idx) { + auto& tfd = td.field(field_idx); + auto field_id = tfd.getFieldId(); + if (scores.contains(field_id)) { + auto handle = tfd.getHandle(); + if (handle != fef::IllegalHandle) { + auto &field_scores = scores[field_id]; + std::sort(field_scores.begin(), field_scores.end(), std::greater()); + auto tmd = match_data.resolveTermField(tfd.getHandle()); + tmd->setFieldId(field_id); + tmd->reset(docid); + for (auto& field_score : field_scores) { + fef::TermFieldMatchDataPosition pos; + pos.setElementWeight(field_score); + tmd->appendPosition(pos); + } + } + } + } +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h new file mode 100644 index 00000000000..4473e0fa45b --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h @@ -0,0 +1,20 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "multi_term.h" + +namespace search::streaming { + +/* + * A weighted set query term for streaming search. + */ +class WeightedSetTerm : public MultiTerm { + double _score_threshold; +public: + WeightedSetTerm(std::unique_ptr result_base, const string& index, uint32_t num_terms); + ~WeightedSetTerm() override; + void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) override; +}; + +} -- cgit v1.2.3