diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2020-02-20 15:19:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-02-20 15:19:02 +0100 |
commit | 3ab530c774be833992bbec327dfd43a5ee7fa33a (patch) | |
tree | e308374a6b722d7eb01fce3906ec604e4df0c887 /searchlib | |
parent | 8b9ddc1fea064f2851f540b9fdeff94d12c8ffa4 (diff) | |
parent | 97ebc6ec54db9ea2005eb6cd958c3ce3c76cde63 (diff) |
Merge pull request #12281 from vespa-engine/arnej/add-nns-iterator
Arnej/add nns iterator
Diffstat (limited to 'searchlib')
10 files changed, 187 insertions, 9 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 7bc582ab442..691e80aeb9f 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -12,6 +12,7 @@ #include <vespa/searchlib/queryeval/simpleresult.h> #include <vespa/searchlib/tensor/dense_tensor_attribute.h> #include <vespa/vespalib/test/insertion_operators.h> +#include <vespa/searchlib/queryeval/nns_index_iterator.h> #include <vespa/log/log.h> LOG_SETUP("nearest_neighbor_test"); @@ -190,4 +191,70 @@ TEST("require that NearestNeighborIterator sets expected rawscore") { TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecFloat, denseSpecDouble)); } +TEST("require that NnsIndexIterator works as expected") { + std::vector<NnsIndexIterator::Hit> hits{{2,4.0}, {3,9.0}, {5,1.0}, {8,16.0}, {9,36.0}}; + auto md = MatchData::makeTestInstance(2, 2); + auto &tfmd = *(md->resolveTermField(0)); + auto search = NnsIndexIterator::create(true, tfmd, hits); + uint32_t docid = 1; + search->initFullRange(); + bool match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(2u, search->getDocId()); + docid = 2; + match = search->seek(docid); + EXPECT_TRUE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(docid, search->getDocId()); + search->unpack(docid); + EXPECT_EQUAL(2.0, tfmd.getRawScore()); + + docid = 3; + match = search->seek(docid); + EXPECT_TRUE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(docid, search->getDocId()); + search->unpack(docid); + EXPECT_EQUAL(3.0, tfmd.getRawScore()); + + docid = 4; + match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(5u, search->getDocId()); + + docid = 6; + match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(8u, search->getDocId()); + docid = 8; + search->unpack(docid); + EXPECT_EQUAL(4.0, tfmd.getRawScore()); + docid = 9; + match = search->seek(docid); + EXPECT_TRUE(match); + EXPECT_FALSE(search->isAtEnd()); + docid = 10; + match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_TRUE(search->isAtEnd()); + + docid = 4; + search->initRange(docid, 7); + match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_FALSE(search->isAtEnd()); + EXPECT_EQUAL(5u, search->getDocId()); + docid = 5; + search->unpack(docid); + EXPECT_EQUAL(1.0, tfmd.getRawScore()); + EXPECT_FALSE(search->isAtEnd()); + docid = 6; + match = search->seek(docid); + EXPECT_FALSE(match); + EXPECT_TRUE(search->isAtEnd()); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index cd0d4bcaad0..1204ae1e9bc 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -102,8 +102,10 @@ public: if (exp_hits.size() == k) { std::vector<uint32_t> expected_by_docid = exp_hits; std::sort(expected_by_docid.begin(), expected_by_docid.end()); - std::vector<uint32_t> got_by_docid = index->find_top_k(k, qv, k); - EXPECT_EQ(expected_by_docid, got_by_docid); + auto got_by_docid = index->find_top_k(k, qv, k); + for (idx = 0; idx < k; ++idx) { + EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid); + } } } }; diff --git a/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt b/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt index de2919443ff..0dcb0393473 100644 --- a/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt @@ -32,6 +32,7 @@ vespa_add_library(searchlib_queryeval OBJECT nearest_neighbor_blueprint.cpp nearest_neighbor_iterator.cpp nearsearch.cpp + nns_index_iterator.cpp orsearch.cpp predicate_blueprint.cpp predicate_search.cpp diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp new file mode 100644 index 00000000000..222f02d1941 --- /dev/null +++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp @@ -0,0 +1,65 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "nns_index_iterator.h" +#include <vespa/searchlib/tensor/nearest_neighbor_index.h> +#include <cmath> + +using Hit = search::tensor::NearestNeighborIndex::Neighbor; + +namespace search::queryeval { + +class NeighborVectorIterator : public NnsIndexIterator +{ +private: + fef::TermFieldMatchData &_tfmd; + const std::vector<Hit> &_hits; + uint32_t _idx; + double _last_sq_dist; +public: + NeighborVectorIterator(fef::TermFieldMatchData &tfmd, + const std::vector<Hit> &hits) + : _tfmd(tfmd), + _hits(hits), + _idx(0), + _last_sq_dist(0.0) + {} + + void initRange(uint32_t begin_id, uint32_t end_id) override { + SearchIterator::initRange(begin_id, end_id); + _idx = 0; + } + + void doSeek(uint32_t docId) override { + while (_idx < _hits.size()) { + uint32_t hit_id = _hits[_idx].docid; + if (hit_id < docId) { + ++_idx; + } else if (hit_id < getEndId()) { + setDocId(hit_id); + _last_sq_dist = _hits[_idx].distance; + return; + } else { + _idx = _hits.size(); + } + } + setAtEnd(); + } + + void doUnpack(uint32_t docId) override { + _tfmd.setRawScore(docId, sqrt(_last_sq_dist)); + } + + Trinary is_strict() const override { return Trinary::True; } +}; + +std::unique_ptr<NnsIndexIterator> +NnsIndexIterator::create( + bool strict, + fef::TermFieldMatchData &tfmd, + const std::vector<Hit> &hits) +{ + assert(strict); + return std::make_unique<NeighborVectorIterator>(tfmd, hits); +} + +} // namespace diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h new file mode 100644 index 00000000000..62fa49aac46 --- /dev/null +++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h @@ -0,0 +1,21 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "searchiterator.h" +#include <vespa/searchlib/fef/termfieldmatchdata.h> +#include <vespa/searchlib/tensor/nearest_neighbor_index.h> + +namespace search::queryeval { + +class NnsIndexIterator : public SearchIterator +{ +public: + using Hit = search::tensor::NearestNeighborIndex::Neighbor; + static std::unique_ptr<NnsIndexIterator> create( + bool strict, + fef::TermFieldMatchData &tfmd, + const std::vector<Hit> &hits); +}; + +} // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt index 09069861ab4..0bdcd53af77 100644 --- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt @@ -6,12 +6,13 @@ vespa_add_library(searchlib_tensor OBJECT dense_tensor_attribute_saver.cpp dense_tensor_store.cpp generic_tensor_attribute.cpp + generic_tensor_attribute_saver.cpp generic_tensor_store.cpp hnsw_index.cpp imported_tensor_attribute_vector.cpp imported_tensor_attribute_vector_read_guard.cpp + nearest_neighbor_index.cpp tensor_attribute.cpp - generic_tensor_attribute_saver.cpp tensor_store.cpp DEPENDS ) diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index 860686f3c6a..0d90e4e822a 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -310,19 +310,27 @@ HnswIndex::remove_document(uint32_t docid) _node_refs[docid].store_release(invalid); } -std::vector<uint32_t> +struct NeighborsByDocId { + bool operator() (const NearestNeighborIndex::Neighbor &lhs, + const NearestNeighborIndex::Neighbor &rhs) + { + return (lhs.docid < rhs.docid); + } +}; + +std::vector<NearestNeighborIndex::Neighbor> HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) { - std::vector<uint32_t> result; + std::vector<Neighbor> result; FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k)); while (candidates.size() > k) { candidates.pop(); } result.reserve(candidates.size()); for (const HnswCandidate & hit : candidates.peek()) { - result.emplace_back(hit.docid); + result.emplace_back(hit.docid, hit.distance); } - std::sort(result.begin(), result.end()); + std::sort(result.begin(), result.end(), NeighborsByDocId()); return result; } diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index 66d6a6d25c2..ae9927be7a8 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -135,7 +135,7 @@ public: void add_document(uint32_t docid) override; void remove_document(uint32_t docid) override; - std::vector<uint32_t> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) override; + std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) override; FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k); // TODO: Add support for generation handling and cleanup (transfer_hold_lists, trim_hold_lists) diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.cpp b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.cpp new file mode 100644 index 00000000000..f31230af381 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.cpp @@ -0,0 +1,3 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "nearest_neighbor_index.h" diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h index 2ae322fe76e..57e8beff84e 100644 --- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h @@ -13,10 +13,20 @@ namespace search::tensor { */ class NearestNeighborIndex { public: + struct Neighbor { + uint32_t docid; + double distance; + Neighbor(uint32_t id, double dist) + : docid(id), distance(dist) + {} + Neighbor() : docid(0), distance(0.0) {} + }; virtual ~NearestNeighborIndex() {} virtual void add_document(uint32_t docid) = 0; virtual void remove_document(uint32_t docid) = 0; - virtual std::vector<uint32_t> find_top_k(uint32_t k, vespalib::tensor::TypedCells vector, uint32_t explore_k) = 0; + virtual std::vector<Neighbor> find_top_k(uint32_t k, + vespalib::tensor::TypedCells vector, + uint32_t explore_k) = 0; }; } |