aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@online.no>2024-01-15 16:04:45 +0100
committerTor Egge <Tor.Egge@online.no>2024-01-15 16:04:45 +0100
commitbf4f3338864574021ce7260dd527310e84bfeb32 (patch)
treef2ebcad28e51e55e13a871c52a47e36b72648627
parent008338e2cfc4b455a07b8a8ab129b8cbbf9d0f2f (diff)
Add WeightedSetTerm for streaming search.
-rw-r--r--searchlib/src/tests/query/streaming_query_test.cpp37
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/query.h3
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/querynode.cpp16
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/querynode.h1
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp53
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h20
7 files changed, 127 insertions, 4 deletions
diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp
index 7c4b7555158..52f01af3dff 100644
--- a/searchlib/src/tests/query/streaming_query_test.cpp
+++ b/searchlib/src/tests/query/streaming_query_test.cpp
@@ -7,6 +7,7 @@
#include <vespa/searchlib/query/streaming/query.h>
#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h>
#include <vespa/searchlib/query/streaming/wand_term.h>
+#include <vespa/searchlib/query/streaming/weighted_set_term.h>
#include <vespa/searchlib/query/tree/querybuilder.h>
#include <vespa/searchlib/query/tree/simplequery.h>
#include <vespa/searchlib/query/tree/stackdumpcreator.h>
@@ -1020,6 +1021,42 @@ TEST(StreamingQueryTest, wand_term)
check_wand_term(exp_wand_score_field_11 + 1, "hidden score below limit");
}
+TEST(StreamingQueryTest, weighted_set_term)
+{
+ search::streaming::WeightedSetTerm term({}, "index", 2);
+ term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "7", "", QueryTermSimple::Type::WORD));
+ term.get_terms().back()->setWeight(Weight(4));
+ term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "9", "", QueryTermSimple::Type::WORD));
+ term.get_terms().back()->setWeight(Weight(13));
+ EXPECT_EQ(2, term.get_terms().size());
+ SimpleTermData td;
+ td.addField(10);
+ td.addField(11);
+ td.addField(12);
+ td.lookupField(10)->setHandle(0);
+ td.lookupField(12)->setHandle(1);
+ EXPECT_FALSE(term.evaluate());
+ auto& q0 = *term.get_terms()[0];
+ q0.add(0, 11, 0, 10);
+ q0.add(0, 12, 0, 10);
+ auto& q1 = *term.get_terms()[1];
+ q1.add(0, 11, 0, 10);
+ q1.add(0, 12, 0, 10);
+ EXPECT_TRUE(term.evaluate());
+ MatchData md(MatchData::params().numTermFields(2));
+ term.unpack_match_data(23, td, md);
+ auto tmd0 = md.resolveTermField(0);
+ EXPECT_NE(23, tmd0->getDocId());
+ auto tmd1 = md.resolveTermField(1);
+ EXPECT_EQ(23, tmd1->getDocId());
+ using Weights = std::vector<int32_t>;
+ Weights weights;
+ for (auto& pos : *tmd1) {
+ weights.emplace_back(pos.getElementWeight());
+ }
+ EXPECT_EQ((Weights{13, 4}), weights);
+}
+
TEST(StreamingQueryTest, control_the_size_of_query_terms)
{
EXPECT_EQ(112u, sizeof(QueryTermSimple));
diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
index 9b53407aff5..6b9be2e3269 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
@@ -10,5 +10,6 @@ vespa_add_library(searchlib_query_streaming OBJECT
querynoderesultbase.cpp
queryterm.cpp
wand_term.cpp
+ weighted_set_term.cpp
DEPENDS
)
diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.h b/searchlib/src/vespa/searchlib/query/streaming/query.h
index 8befa2fe7fa..84c693b86d0 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/query.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/query.h
@@ -103,8 +103,7 @@ public:
EquivQueryNode() noexcept : OrQueryNode("EQUIV") { }
bool evaluate() const override;
bool isFlattenable(ParseItem::ItemType type) const override {
- return (type == ParseItem::ITEM_EQUIV) ||
- (type == ParseItem::ITEM_WEIGHTED_SET);
+ return (type == ParseItem::ITEM_EQUIV);
}
};
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
index c24f41d16cf..1ce80660d46 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
@@ -6,6 +6,7 @@
#include <vespa/searchlib/query/streaming/dot_product_term.h>
#include <vespa/searchlib/query/streaming/in_term.h>
#include <vespa/searchlib/query/streaming/wand_term.h>
+#include <vespa/searchlib/query/streaming/weighted_set_term.h>
#include <vespa/searchlib/query/tree/term_vector.h>
#include <charconv>
#include <vespa/log/log.h>
@@ -40,7 +41,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
case ParseItem::ITEM_OR:
case ParseItem::ITEM_WEAK_AND:
case ParseItem::ITEM_EQUIV:
- case ParseItem::ITEM_WEIGHTED_SET:
case ParseItem::ITEM_NOT:
case ParseItem::ITEM_PHRASE:
case ParseItem::ITEM_SAME_ELEMENT:
@@ -55,7 +55,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
nqn->distance(queryRep.getNearDistance());
}
if ((type == ParseItem::ITEM_WEAK_AND) ||
- (type == ParseItem::ITEM_WEIGHTED_SET) ||
(type == ParseItem::ITEM_SAME_ELEMENT))
{
qn->setIndex(queryRep.getIndexName());
@@ -192,6 +191,9 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
case ParseItem::ITEM_WAND:
qn = build_wand_term(factory, queryRep);
break;
+ case ParseItem::ITEM_WEIGHTED_SET:
+ qn = build_weighted_set_term(factory, queryRep);
+ break;
default:
skip_unknown(queryRep);
break;
@@ -270,6 +272,16 @@ QueryNode::build_wand_term(const QueryNodeResultFactory& factory, SimpleQuerySta
return wand;
}
+std::unique_ptr<QueryNode>
+QueryNode::build_weighted_set_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep)
+{
+ auto ws = std::make_unique<WeightedSetTerm>(factory.create(), queryRep.getIndexName(), queryRep.getArity());
+ ws->setWeight(queryRep.GetWeight());
+ ws->setUniqueId(queryRep.getUniqueId());
+ populate_multi_term(factory.normalizing_mode(ws->index()), *ws, queryRep);
+ return ws;
+}
+
void
QueryNode::skip_unknown(SimpleQueryStackDumpIterator& queryRep)
{
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.h b/searchlib/src/vespa/searchlib/query/streaming/querynode.h
index a0561b2e52e..454932c0a68 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynode.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.h
@@ -32,6 +32,7 @@ class QueryNode
static void populate_multi_term(Normalizing string_normalize_mode, MultiTerm& mt, SimpleQueryStackDumpIterator& queryRep);
static std::unique_ptr<QueryNode> build_dot_product_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep);
static std::unique_ptr<QueryNode> build_wand_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep);
+ static std::unique_ptr<QueryNode> build_weighted_set_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep);
static void skip_unknown(SimpleQueryStackDumpIterator& queryRep);
public:
using UP = std::unique_ptr<QueryNode>;
diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp
new file mode 100644
index 00000000000..90d0be5d43c
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp
@@ -0,0 +1,53 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "weighted_set_term.h"
+#include <vespa/searchlib/fef/itermdata.h>
+#include <vespa/searchlib/fef/matchdata.h>
+#include <vespa/vespalib/stllike/hash_map.hpp>
+
+using search::fef::ITermData;
+using search::fef::MatchData;
+
+namespace search::streaming {
+
+WeightedSetTerm::WeightedSetTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string & index, uint32_t num_terms)
+ : MultiTerm(std::move(result_base), index, num_terms)
+{
+}
+
+WeightedSetTerm::~WeightedSetTerm() = default;
+
+void
+WeightedSetTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data)
+{
+ vespalib::hash_map<uint32_t,std::vector<double>> scores;
+ HitList hl_store;
+ for (const auto& term : _terms) {
+ auto& hl = term->evaluateHits(hl_store);
+ for (auto& hit : hl) {
+ scores[hit.context()].emplace_back(term->weight().percent());
+ }
+ }
+ auto num_fields = td.numFields();
+ for (uint32_t field_idx = 0; field_idx < num_fields; ++field_idx) {
+ auto& tfd = td.field(field_idx);
+ auto field_id = tfd.getFieldId();
+ if (scores.contains(field_id)) {
+ auto handle = tfd.getHandle();
+ if (handle != fef::IllegalHandle) {
+ auto &field_scores = scores[field_id];
+ std::sort(field_scores.begin(), field_scores.end(), std::greater());
+ auto tmd = match_data.resolveTermField(tfd.getHandle());
+ tmd->setFieldId(field_id);
+ tmd->reset(docid);
+ for (auto& field_score : field_scores) {
+ fef::TermFieldMatchDataPosition pos;
+ pos.setElementWeight(field_score);
+ tmd->appendPosition(pos);
+ }
+ }
+ }
+ }
+}
+
+}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h
new file mode 100644
index 00000000000..4473e0fa45b
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h
@@ -0,0 +1,20 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "multi_term.h"
+
+namespace search::streaming {
+
+/*
+ * A weighted set query term for streaming search.
+ */
+class WeightedSetTerm : public MultiTerm {
+ double _score_threshold;
+public:
+ WeightedSetTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string& index, uint32_t num_terms);
+ ~WeightedSetTerm() override;
+ void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) override;
+};
+
+}