diff options
5 files changed, 34 insertions, 11 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 13caad8b6d6..bea371c78a8 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -234,10 +234,9 @@ public: void load_index(std::vector<char> data) { auto& graph = index->get_graph(); - HnswIndexLoader<VectorBufferReader, IndexType::index_type> loader(graph, std::make_unique<VectorBufferReader>(data)); - while (loader.load_next()) {} auto& id_mapping = index->get_id_mapping(); - id_mapping.on_load(graph.node_refs.make_read_view(graph.node_refs.size())); + HnswIndexLoader<VectorBufferReader, IndexType::index_type> loader(graph, id_mapping, std::make_unique<VectorBufferReader>(data)); + while (loader.load_next()) {} } static constexpr bool is_single = std::is_same_v<IndexType, HnswIndex<HnswIndexType::SINGLE>>; diff --git a/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp b/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp index 7495fa18c4d..bf4abdd7cf8 100644 --- a/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp @@ -1,8 +1,11 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/searchlib/tensor/hnsw_graph.h> +#include <vespa/searchlib/tensor/hnsw_identity_mapping.h> #include <vespa/searchlib/tensor/hnsw_index_saver.h> #include <vespa/searchlib/tensor/hnsw_index_loader.hpp> +#include <vespa/searchlib/tensor/hnsw_index_traits.h> +#include <vespa/searchlib/tensor/hnsw_nodeid_mapping.h> #include <vespa/searchlib/test/vector_buffer_reader.h> #include <vespa/searchlib/test/vector_buffer_writer.h> #include <vespa/searchlib/util/fileutil.h> @@ -32,7 +35,14 @@ uint32_t fake_docid<HnswIndexType::SINGLE>(uint32_t nodeid) template <> uint32_t fake_docid<HnswIndexType::MULTI>(uint32_t nodeid) { - return nodeid + 100; + switch (nodeid) { + case 5: + return 104; + case 6: + return 104; + default: + return nodeid + 100; + } } template <HnswIndexType type> @@ -47,7 +57,14 @@ uint32_t fake_subspace<HnswIndexType::SINGLE>(uint32_t) template <> uint32_t fake_subspace<HnswIndexType::MULTI>(uint32_t nodeid) { - return nodeid + 10; + switch (nodeid) { + case 5: + return 2; + case 6: + return 1; + default: + return 0; + } } template <typename NodeType> @@ -69,7 +86,7 @@ template <HnswIndexType type> void populate(HnswGraph<type> &graph) { // no 0 graph.make_node(1, fake_docid<type>(1), fake_subspace<type>(1), 1); - auto er = graph.make_node(2, 102, 12, 2); + auto er = graph.make_node(2, fake_docid<type>(2), fake_subspace<type>(2), 2); // no 3 graph.make_node(4, fake_docid<type>(4), fake_subspace<type>(4), 2); graph.make_node(5, fake_docid<type>(5), fake_subspace<type>(5), 0); @@ -137,7 +154,8 @@ public: return vector_writer.output; } void load_copy(std::vector<char> data) { - HnswIndexLoader<VectorBufferReader, GraphType::index_type> loader(copy, std::make_unique<VectorBufferReader>(data)); + typename HnswIndexTraits<GraphType::index_type>::IdMapping id_mapping; + HnswIndexLoader<VectorBufferReader, GraphType::index_type> loader(copy, id_mapping, std::make_unique<VectorBufferReader>(data)); while (loader.load_next()) {} } diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index ce9f1ad9ca7..e9e52301f8e 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -777,7 +777,7 @@ HnswIndex<type>::make_loader(FastOS_FileInterface& file) assert(get_entry_nodeid() == 0); // cannot load after index has data using ReaderType = FileReader<uint32_t>; using LoaderType = HnswIndexLoader<ReaderType, type>; - return std::make_unique<LoaderType>(_graph, std::make_unique<ReaderType>(&file)); + return std::make_unique<LoaderType>(_graph, _id_mapping, std::make_unique<ReaderType>(&file)); } struct NeighborsByDocId { diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.h index efe15011776..721276ef0ab 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.h @@ -3,6 +3,7 @@ #pragma once #include "nearest_neighbor_index_loader.h" +#include "hnsw_index_traits.h" #include <vespa/vespalib/util/exceptions.h> #include <cstdint> #include <memory> @@ -21,6 +22,8 @@ struct HnswGraph; template <typename ReaderType, HnswIndexType type> class HnswIndexLoader : public NearestNeighborIndexLoader { private: + using IdMapping = typename HnswIndexTraits<type>::IdMapping; + HnswGraph<type>& _graph; std::unique_ptr<ReaderType> _reader; uint32_t _entry_nodeid; @@ -29,6 +32,7 @@ private: uint32_t _nodeid; std::vector<uint32_t> _link_array; bool _complete; + IdMapping& _id_mapping; void init(); uint32_t next_int() { @@ -36,7 +40,7 @@ private: } public: - HnswIndexLoader(HnswGraph<type>& graph, std::unique_ptr<ReaderType> reader); + HnswIndexLoader(HnswGraph<type>& graph, IdMapping& id_mapping, std::unique_ptr<ReaderType> reader); virtual ~HnswIndexLoader(); bool load_next() override; }; diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.hpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.hpp index 04e1fcc1792..de9cc760fec 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.hpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.hpp @@ -22,7 +22,7 @@ template <typename ReaderType, HnswIndexType type> HnswIndexLoader<ReaderType, type>::~HnswIndexLoader() = default; template <typename ReaderType, HnswIndexType type> -HnswIndexLoader<ReaderType, type>::HnswIndexLoader(HnswGraph<type>& graph, std::unique_ptr<ReaderType> reader) +HnswIndexLoader<ReaderType, type>::HnswIndexLoader(HnswGraph<type>& graph, IdMapping& id_mapping, std::unique_ptr<ReaderType> reader) : _graph(graph), _reader(std::move(reader)), _entry_nodeid(0), @@ -30,7 +30,8 @@ HnswIndexLoader<ReaderType, type>::HnswIndexLoader(HnswGraph<type>& graph, std:: _num_nodes(0), _nodeid(0), _link_array(), - _complete(false) + _complete(false), + _id_mapping(id_mapping) { init(); } @@ -65,6 +66,7 @@ HnswIndexLoader<ReaderType, type>::load_next() _graph.trim_node_refs_size(); auto entry_node_ref = _graph.get_node_ref(_entry_nodeid); _graph.set_entry_node({_entry_nodeid, entry_node_ref, _entry_level}); + _id_mapping.on_load(_graph.node_refs.make_read_view(_graph.node_refs.size())); _complete = true; return false; } |