diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-02-17 14:07:24 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-02-17 14:07:24 +0000 |
commit | 1ab187d4152775a22d1b554fd2364ec5058cdfee (patch) | |
tree | 220c359f1dc244174748d041756d56655b814c77 /searchlib | |
parent | f4d3589c6b00c330246fa486afc5fe20f6a0f9fe (diff) |
add search API in NearestNeighborIndex
Diffstat (limited to 'searchlib')
4 files changed, 22 insertions, 3 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 1965d58c0e7..805012d224c 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -87,7 +87,7 @@ public: EXPECT_EQ(exp_levels, act_node.levels()); } void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { - auto rv = index->find_top_k(vectors.get_vector(docid), 3); + auto rv = index->top_k_candidates(vectors.get_vector(docid), 3); size_t idx = 0; for (const auto & hit : rv) { // fprintf(stderr, "found docid %u dist %.1f\n", hit.docid, hit.distance); diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index 85010be1e17..4bed9097a5c 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -310,8 +310,23 @@ HnswIndex::remove_document(uint32_t docid) _node_refs[docid].store_release(invalid); } +std::vector<uint32_t> +HnswIndex::find_top_k(TypedCells vector, uint32_t k) +{ + std::vector<uint32_t> result; + std::vector<HnswCandidate> candidates = top_k_candidates(vector, k + 100); + result.reserve(std::min((size_t)k, candidates.size())); + for (const HnswCandidate & hit : candidates) { + result.emplace_back(hit.docid); + if (result.size() == k) break; + } + std::sort(result.begin(), result.end()); + return result; +} + + std::vector<HnswCandidate> -HnswIndex::find_top_k(const TypedCells &vector, uint32_t k) +HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k) { std::vector<HnswCandidate> result; if (_entry_level < 0) { diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index af20dfd9a78..2454ff85884 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -134,7 +134,8 @@ public: void add_document(uint32_t docid) override; void remove_document(uint32_t docid) override; - std::vector<HnswCandidate> find_top_k(const TypedCells &vector, uint32_t k); + std::vector<uint32_t> find_top_k(TypedCells vector, uint32_t k) override; + std::vector<HnswCandidate> top_k_candidates(const TypedCells &vector, uint32_t k); // TODO: Add support for generation handling and cleanup (transfer_hold_lists, trim_hold_lists) diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h index 2167157f6cb..718b96c92b4 100644 --- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h @@ -3,6 +3,8 @@ #pragma once #include <cstdint> +#include <vector> +#include <vespa/eval/tensor/dense/typed_cells.h> namespace search::tensor { @@ -14,6 +16,7 @@ public: virtual ~NearestNeighborIndex() {} virtual void add_document(uint32_t docid) = 0; virtual void remove_document(uint32_t docid) = 0; + virtual std::vector<uint32_t> find_top_k(vespalib::tensor::TypedCells vector, uint32_t k) = 0; }; } |