diff options
Diffstat (limited to 'searchlib/src')
4 files changed, 62 insertions, 17 deletions
diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index f354f635def..2c202d9131b 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/searchlib/query/streaming/query.h> +#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> #include <vespa/searchlib/query/tree/querybuilder.h> #include <vespa/searchlib/query/tree/simplequery.h> #include <vespa/searchlib/query/tree/stackdumpcreator.h> @@ -804,6 +805,42 @@ TEST("testSameElementEvaluate") { EXPECT_TRUE(sameElem->evaluate()); } +TEST("test_nearest_neighbor_query_node") +{ + QueryBuilder<SimpleQueryNodeTypes> builder; + constexpr double distance_threshold = 35.5; + constexpr int32_t id = 42; + constexpr int32_t weight = 1; + constexpr uint32_t target_num_hits = 100; + constexpr bool allow_approximate = false; + constexpr uint32_t explore_additional_hits = 800; + constexpr double raw_score = 0.5; + builder.add_nearest_neighbor_term("qtensor", "field", id, Weight(weight), target_num_hits, allow_approximate, explore_additional_hits, distance_threshold); + auto build_node = builder.build(); + auto stack_dump = StackDumpCreator::create(*build_node); + QueryNodeResultFactory empty; + Query q(empty, stack_dump); + auto* qterm = dynamic_cast<QueryTerm *>(&q.getRoot()); + EXPECT_TRUE(qterm != nullptr); + auto* node = dynamic_cast<NearestNeighborQueryNode *>(&q.getRoot()); + EXPECT_TRUE(node != nullptr); + EXPECT_EQUAL(node, qterm->as_nearest_neighbor_query_node()); + EXPECT_EQUAL("qtensor", node->get_query_tensor_name()); + EXPECT_EQUAL("field", node->getIndex()); + EXPECT_EQUAL(id, static_cast<int32_t>(node->uniqueId())); + EXPECT_EQUAL(weight, node->weight().percent()); + EXPECT_EQUAL(distance_threshold, node->get_distance_threshold()); + EXPECT_FALSE(node->get_raw_score().has_value()); + EXPECT_FALSE(node->evaluate()); + node->set_raw_score(raw_score); + EXPECT_TRUE(node->get_raw_score().has_value()); + EXPECT_EQUAL(raw_score, node->get_raw_score().value()); + EXPECT_TRUE(node->evaluate()); + node->reset(); + EXPECT_FALSE(node->get_raw_score().has_value()); + EXPECT_FALSE(node->evaluate()); +} + TEST("Control the size of query terms") { EXPECT_EQUAL(112u, sizeof(QueryTermSimple)); EXPECT_EQUAL(128u, sizeof(QueryTermUCS4)); diff --git a/searchlib/src/vespa/searchlib/attribute/attributevector.cpp b/searchlib/src/vespa/searchlib/attribute/attributevector.cpp index 9110c08099a..f4ab447ed51 100644 --- a/searchlib/src/vespa/searchlib/attribute/attributevector.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attributevector.cpp @@ -453,19 +453,6 @@ AttributeVector::set_reserved_doc_values() return; } clearDoc(docId); - if (hasMultiValue()) { - if (isFloatingPointType()) { - auto * vec = dynamic_cast<FloatingPointAttribute *>(this); - bool appendedUndefined = vec->append(0, attribute::getUndefined<double>(), 1); - assert(appendedUndefined); - (void) appendedUndefined; - } else if (isStringType()) { - auto * vec = dynamic_cast<StringAttribute *>(this); - bool appendedUndefined = vec->append(0, StringAttribute::defaultValue(), 1); - assert(appendedUndefined); - (void) appendedUndefined; - } - } commit(); } diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp index fdc513f9617..d1c37cd6dcd 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp @@ -6,7 +6,8 @@ namespace search::streaming { NearestNeighborQueryNode::NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold) : QueryTerm(std::move(resultBase), term, index, Type::NEAREST_NEIGHBOR), - _distance_threshold(distance_threshold) + _distance_threshold(distance_threshold), + _raw_score() { setUniqueId(id); setWeight(weight); @@ -14,6 +15,18 @@ NearestNeighborQueryNode::NearestNeighborQueryNode(std::unique_ptr<QueryNodeResu NearestNeighborQueryNode::~NearestNeighborQueryNode() = default; +bool +NearestNeighborQueryNode::evaluate() const +{ + return _raw_score.has_value(); +} + +void +NearestNeighborQueryNode::reset() +{ + _raw_score.reset(); +} + NearestNeighborQueryNode* NearestNeighborQueryNode::as_nearest_neighbor_query_node() noexcept { diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h index ddc84a4b6d3..0beb130c53d 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h +++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h @@ -3,15 +3,19 @@ #pragma once #include "queryterm.h" +#include <optional> namespace search::streaming { /* * Nearest neighbor query node. */ -class NearestNeighborQueryNode: public QueryTerm -{ - double _distance_threshold; +class NearestNeighborQueryNode: public QueryTerm { +private: + double _distance_threshold; + // When this value is set it also indicates a match + std::optional<double> _raw_score; + public: NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold); NearestNeighborQueryNode(const NearestNeighborQueryNode &) = delete; @@ -19,9 +23,13 @@ public: NearestNeighborQueryNode(NearestNeighborQueryNode &&) = delete; NearestNeighborQueryNode & operator = (NearestNeighborQueryNode &&) = delete; ~NearestNeighborQueryNode() override; + bool evaluate() const override; + void reset() override; NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept override; const vespalib::string& get_query_tensor_name() const { return getTermString(); } double get_distance_threshold() const { return _distance_threshold; } + void set_raw_score(double value) { _raw_score = value; } + const std::optional<double>& get_raw_score() const noexcept { return _raw_score; } }; } |