diff options
author | Geir Storli <geirst@yahooinc.com> | 2023-04-25 16:20:45 +0000 |
---|---|---|
committer | Geir Storli <geirst@yahooinc.com> | 2023-04-25 16:20:45 +0000 |
commit | 1d6fcfb5c5b7399cd33c32ffea30fd9208ec000b (patch) | |
tree | 4de5f00fa4e84e779d9ae01b766a6a0a6a61416f /streamingvisitors/src/tests | |
parent | f3ac0e360e47778eb51e3619825f09e52d3b6082 (diff) |
Use targetHits in nearestNeighbor streaming searcher.
A distance heap is used to limit the number of produced document matches.
Diffstat (limited to 'streamingvisitors/src/tests')
-rw-r--r-- | streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp | 61 | ||||
-rw-r--r-- | streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp | 12 |
2 files changed, 59 insertions, 14 deletions
diff --git a/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp b/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp index 43c77398be8..b64d477fd4c 100644 --- a/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp +++ b/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp @@ -31,9 +31,11 @@ struct MockQuery { std::vector<std::unique_ptr<NearestNeighborQueryNode>> nodes; QueryTermList term_list; MockQuery& add(const vespalib::string& query_tensor_name, + uint32_t target_hits, double distance_threshold) { std::unique_ptr<QueryNodeResultBase> base; - auto node = std::make_unique<NearestNeighborQueryNode>(std::move(base), query_tensor_name, "my_tensor_field", 7, search::query::Weight(11), distance_threshold); + auto node = std::make_unique<NearestNeighborQueryNode>(std::move(base), query_tensor_name, "my_tensor_field", + target_hits, distance_threshold, 7, search::query::Weight(100)); nodes.push_back(std::move(node)); term_list.push_back(nodes.back().get()); return *this; @@ -90,34 +92,71 @@ public: query.reset(); searcher.onValue(fv); } + void expect_match(const vespalib::string& spec_expr, double exp_square_distance, const NearestNeighborQueryNode& node) { + match(spec_expr); + expect_match(exp_square_distance, node); + } void expect_match(double exp_square_distance, const NearestNeighborQueryNode& node) { double exp_raw_score = dist_func.to_rawscore(exp_square_distance); EXPECT_TRUE(node.evaluate()); + EXPECT_DOUBLE_EQ(exp_square_distance, node.get_distance().value()); EXPECT_DOUBLE_EQ(exp_raw_score, node.get_raw_score().value()); } + void expect_not_match(const vespalib::string& spec_expr, const NearestNeighborQueryNode& node) { + match(spec_expr); + EXPECT_FALSE(node.evaluate()); + } }; -TEST_F(NearestNeighborSearcherTest, raw_score_calculated_with_distance_threshold) +TEST_F(NearestNeighborSearcherTest, distance_heap_keeps_the_best_target_hits) { - query.add("qt1", 3); + query.add("qt1", 2, 100.0); + const auto& node = query.get(0); set_query_tensor("qt1", "tensor(x[2]):[1,3]"); prepare(); - match("tensor(x[2]):[1,5]"); - expect_match((5-3)*(5-3), query.get(0)); + expect_match("tensor(x[2]):[1,7]", (7-3)*(7-3), node); + expect_match("tensor(x[2]):[1,9]", (9-3)*(9-3), node); - match("tensor(x[2]):[1,6]"); - expect_match((6-3)*(6-3), query.get(0)); + // The distance limit is now (9-3)*(9-3) = 36, so this is not good enough. + expect_not_match("tensor(x[2]):[1,10]", node); + + expect_match("tensor(x[2]):[1,5]", (5-3)*(5-3), node); + + // The distance limit is now (7-3)*(7-3) = 16, so this is not good enough. + expect_not_match("tensor(x[2]):[1,8]", node); + + // This is not considered a document match as get_raw_score() is not called, + // and the distance heap is not updated. + match("tensor(x[2]):[1,4]"); + EXPECT_EQ(1, node.get_distance().value()); + EXPECT_TRUE(node.evaluate()); + + // The distance limit is still (7-3)*(7-3) = 16, so this is in fact good enough. + expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node); + + // The distance limit is (6-3)*(6-3) = 4, and a similar distance is a match. + expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node); +} + +TEST_F(NearestNeighborSearcherTest, raw_score_calculated_with_distance_threshold) +{ + query.add("qt1", 10, 3.0); + const auto& node = query.get(0); + set_query_tensor("qt1", "tensor(x[2]):[1,3]"); + prepare(); + + expect_match("tensor(x[2]):[1,5]", (5-3)*(5-3), node); + expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node); - match("tensor(x[2]):[1,7]"); // This is not a match since ((7-3)*(7-3) = 16) is larger than the internal distance threshold of (3*3 = 9). - EXPECT_FALSE(query.get(0).evaluate()); + expect_not_match("tensor(x[2]):[1,7]", node); } TEST_F(NearestNeighborSearcherTest, raw_score_calculated_for_two_query_operators) { - query.add("qt1", 3); - query.add("qt2", 4); + query.add("qt1", 10, 3.0); + query.add("qt2", 10, 4.0); set_query_tensor("qt1", "tensor(x[2]):[1,3]"); set_query_tensor("qt2", "tensor(x[2]):[1,4]"); prepare(); diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp index 9f3f3d770e4..4d425d9dedd 100644 --- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp +++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp @@ -55,6 +55,10 @@ RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder) _query_wrapper = std::make_unique<QueryWrapper>(*_query); } +class MockRawScoreCalculator : public search::streaming::NearestNeighborQueryNode::RawScoreCalculator { +public: + double to_raw_score(double distance) override { return distance * 2; } +}; TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) { @@ -71,6 +75,8 @@ TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) EXPECT_EQ(1u, term_list.size()); auto node = dynamic_cast<NearestNeighborQueryNode*>(term_list.front().getTerm()); EXPECT_NE(nullptr, node); + MockRawScoreCalculator calc; + node->set_raw_score_calc(&calc); auto& qtd = static_cast<QueryTermData &>(node->getQueryItem()); auto& td = qtd.getTermData(); constexpr TermFieldHandle handle = 27; @@ -82,11 +88,11 @@ TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) EXPECT_EQ(invalid_id, tfmd->getDocId()); RankProcessor::unpack_match_data(1, *md, *_query_wrapper); EXPECT_EQ(invalid_id, tfmd->getDocId()); - constexpr double raw_score = 1.5; - node->set_raw_score(raw_score); + constexpr double distance = 1.5; + node->set_distance(distance); RankProcessor::unpack_match_data(2, *md, *_query_wrapper); EXPECT_EQ(2, tfmd->getDocId()); - EXPECT_EQ(raw_score, tfmd->getRawScore()); + EXPECT_EQ(distance * 2, tfmd->getRawScore()); node->reset(); RankProcessor::unpack_match_data(3, *md, *_query_wrapper); EXPECT_EQ(2, tfmd->getDocId()); |