summaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp')
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp217
1 files changed, 191 insertions, 26 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index 345f7f551a6..9600f2fd9d4 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -1,37 +1,180 @@
// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#include "distance_function.h"
#include "hnsw_index.h"
+#include "random_level_generator.h"
+#include <vespa/eval/tensor/dense/typed_cells.h>
+#include <vespa/vespalib/datastore/array_store.hpp>
#include <vespa/vespalib/util/rcuvector.hpp>
namespace search::tensor {
-template <typename FloatType>
+namespace {
+
+// TODO: Move this to MemoryAllocator, with name PAGE_SIZE.
+constexpr size_t small_page_size = 4 * 1024;
+constexpr size_t min_num_arrays_for_new_buffer = 8 * 1024;
+constexpr float alloc_grow_factor = 0.2;
+// TODO: Adjust these numbers to what we accept as max in config.
+constexpr size_t max_level_array_size = 16;
+constexpr size_t max_link_array_size = 64;
+
+}
+
+search::datastore::ArrayStoreConfig
+HnswIndex::make_default_node_store_config()
+{
+ return NodeStore::optimizedConfigForHugePage(max_level_array_size, vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
+ small_page_size, min_num_arrays_for_new_buffer, alloc_grow_factor).enable_free_lists(true);
+}
+
+search::datastore::ArrayStoreConfig
+HnswIndex::make_default_link_store_config()
+{
+ return LinkStore::optimizedConfigForHugePage(max_link_array_size, vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
+ small_page_size, min_num_arrays_for_new_buffer, alloc_grow_factor).enable_free_lists(true);
+}
+
+uint32_t
+HnswIndex::max_links_for_level(uint32_t level) const
+{
+ return (level == 0) ? _cfg.max_links_at_level_0() : _cfg.max_links_at_hierarchic_levels();
+}
+
+uint32_t
+HnswIndex::make_node_for_document(uint32_t docid)
+{
+ uint32_t max_level = _level_generator.max_level();
+ // TODO: Add capping on num_levels
+ uint32_t num_levels = max_level + 1;
+ // Note: The level array instance lives as long as the document is present in the index.
+ LevelArray levels(num_levels, AtomicEntryRef());
+ auto node_ref = _nodes.add(levels);
+ _node_refs[docid].store_release(node_ref);
+ return max_level;
+}
+
+HnswIndex::LevelArrayRef
+HnswIndex::get_level_array(uint32_t docid) const
+{
+ auto node_ref = _node_refs[docid].load_acquire();
+ return _nodes.get(node_ref);
+}
+
+HnswIndex::LinkArrayRef
+HnswIndex::get_link_array(uint32_t docid, uint32_t level) const
+{
+ auto levels = get_level_array(docid);
+ assert(level < levels.size());
+ return _links.get(levels[level].load_acquire());
+}
+
+void
+HnswIndex::set_link_array(uint32_t docid, uint32_t level, const LinkArrayRef& links)
+{
+ auto links_ref = _links.add(links);
+ auto node_ref = _node_refs[docid].load_acquire();
+ auto levels = _nodes.get_writable(node_ref);
+ levels[level].store_release(links_ref);
+}
+
+bool
+HnswIndex::have_closer_distance(HnswCandidate candidate, const LinkArray& result) const
+{
+ for (uint32_t result_docid : result) {
+ double dist = calc_distance(candidate.docid, result_docid);
+ if (dist < candidate.distance) {
+ return true;
+ }
+ }
+ return false;
+}
+
+HnswIndex::LinkArray
+HnswIndex::select_neighbors_simple(const HnswCandidateVector& neighbors, uint32_t max_links) const
+{
+ HnswCandidateVector sorted(neighbors);
+ std::sort(sorted.begin(), sorted.end(), LesserDistance());
+ LinkArray result;
+ for (size_t i = 0, m = std::min(static_cast<size_t>(max_links), sorted.size()); i < m; ++i) {
+ result.push_back(sorted[i].docid);
+ }
+ return result;
+}
+
+HnswIndex::LinkArray
+HnswIndex::select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint32_t max_links) const
+{
+ LinkArray result;
+ bool need_filtering = neighbors.size() > max_links;
+ NearestPriQ nearest;
+ for (const auto& entry : neighbors) {
+ nearest.push(entry);
+ }
+ while (!nearest.empty()) {
+ auto candidate = nearest.top();
+ nearest.pop();
+ if (need_filtering && have_closer_distance(candidate, result)) {
+ continue;
+ }
+ result.push_back(candidate.docid);
+ if (result.size() == max_links) {
+ return result;
+ }
+ }
+ return result;
+}
+
+HnswIndex::LinkArray
+HnswIndex::select_neighbors(const HnswCandidateVector& neighbors, uint32_t max_links) const
+{
+ if (_cfg.heuristic_select_neighbors()) {
+ return select_neighbors_heuristic(neighbors, max_links);
+ } else {
+ return select_neighbors_simple(neighbors, max_links);
+ }
+}
+
+void
+HnswIndex::connect_new_node(uint32_t docid, const LinkArray& neighbors, uint32_t level)
+{
+ set_link_array(docid, level, neighbors);
+ for (uint32_t neighbor_docid : neighbors) {
+ auto old_links = get_link_array(neighbor_docid, level);
+ LinkArray new_links(old_links.begin(), old_links.end());
+ new_links.push_back(docid);
+ set_link_array(neighbor_docid, level, new_links);
+ }
+}
+
+void
+HnswIndex::remove_link_to(uint32_t remove_from, uint32_t remove_id, uint32_t level)
+{
+ LinkArray new_links;
+ auto old_links = get_link_array(remove_from, level);
+ for (uint32_t id : old_links) {
+ if (id != remove_id) new_links.push_back(id);
+ }
+ set_link_array(remove_from, level, new_links);
+}
+
+
double
-HnswIndex<FloatType>::calc_distance(uint32_t lhs_docid, uint32_t rhs_docid) const
+HnswIndex::calc_distance(uint32_t lhs_docid, uint32_t rhs_docid) const
{
auto lhs = get_vector(lhs_docid);
return calc_distance(lhs, rhs_docid);
}
-template <typename FloatType>
double
-HnswIndex<FloatType>::calc_distance(const Vector& lhs, uint32_t rhs_docid) const
+HnswIndex::calc_distance(const TypedCells& lhs, uint32_t rhs_docid) const
{
- // TODO: Make it possible to specify the distance function from the outside and make it hardware optimized.
auto rhs = get_vector(rhs_docid);
- double result = 0.0;
- size_t sz = lhs.size();
- assert(sz == rhs.size());
- for (size_t i = 0; i < sz; ++i) {
- double diff = lhs[i] - rhs[i];
- result += diff * diff;
- }
- return result;
+ return _distance_func.calc(lhs, rhs);
}
-template <typename FloatType>
HnswCandidate
-HnswIndex<FloatType>::find_nearest_in_layer(const Vector& input, const HnswCandidate& entry_point, uint32_t level)
+HnswIndex::find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level)
{
HnswCandidate nearest = entry_point;
bool keep_searching = true;
@@ -48,9 +191,8 @@ HnswIndex<FloatType>::find_nearest_in_layer(const Vector& input, const HnswCandi
return nearest;
}
-template <typename FloatType>
void
-HnswIndex<FloatType>::search_layer(const Vector& input, uint32_t neighbors_to_find, FurthestPriQ& best_neighbors, uint32_t level)
+HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& best_neighbors, uint32_t level)
{
NearestPriQ candidates;
// TODO: Add proper handling of visited set.
@@ -85,18 +227,24 @@ HnswIndex<FloatType>::search_layer(const Vector& input, uint32_t neighbors_to_fi
}
}
-template <typename FloatType>
-HnswIndex<FloatType>::HnswIndex(const DocVectorAccess& vectors, RandomLevelGenerator& level_generator, const Config& cfg)
- : HnswIndexBase(vectors, level_generator, cfg)
+HnswIndex::HnswIndex(const DocVectorAccess& vectors, const DistanceFunction& distance_func,
+ RandomLevelGenerator& level_generator, const Config& cfg)
+ : _vectors(vectors),
+ _distance_func(distance_func),
+ _level_generator(level_generator),
+ _cfg(cfg),
+ _node_refs(),
+ _nodes(make_default_node_store_config()),
+ _links(make_default_link_store_config()),
+ _entry_docid(0), // Note that docid 0 is reserved and never used
+ _entry_level(-1)
{
}
-template <typename FloatType>
-HnswIndex<FloatType>::~HnswIndex() = default;
+HnswIndex::~HnswIndex() = default;
-template <typename FloatType>
void
-HnswIndex<FloatType>::add_document(uint32_t docid)
+HnswIndex::add_document(uint32_t docid)
{
auto input = get_vector(docid);
_node_refs.ensure_size(docid + 1, AtomicEntryRef());
@@ -136,9 +284,8 @@ HnswIndex<FloatType>::add_document(uint32_t docid)
}
}
-template <typename FloatType>
void
-HnswIndex<FloatType>::remove_document(uint32_t docid)
+HnswIndex::remove_document(uint32_t docid)
{
bool need_new_entrypoint = (docid == _entry_docid);
LinkArray empty;
@@ -163,5 +310,23 @@ HnswIndex<FloatType>::remove_document(uint32_t docid)
_node_refs[docid].store_release(invalid);
}
+HnswNode
+HnswIndex::get_node(uint32_t docid) const
+{
+ auto node_ref = _node_refs[docid].load_acquire();
+ if (!node_ref.valid()) {
+ return HnswNode();
+ }
+ auto levels = _nodes.get(node_ref);
+ HnswNode::LevelArray result;
+ for (const auto& links_ref : levels) {
+ auto links = _links.get(links_ref.load_acquire());
+ HnswNode::LinkArray result_links(links.begin(), links.end());
+ std::sort(result_links.begin(), result_links.end());
+ result.push_back(result_links);
+ }
+ return HnswNode(result);
+}
+
}