diff options
Diffstat (limited to 'searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp')
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp | 64 |
1 files changed, 26 insertions, 38 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp b/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp index e2a96ec059c..bf4abdd7cf8 100644 --- a/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp @@ -1,9 +1,13 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/searchlib/tensor/hnsw_graph.h> +#include <vespa/searchlib/tensor/hnsw_identity_mapping.h> #include <vespa/searchlib/tensor/hnsw_index_saver.h> #include <vespa/searchlib/tensor/hnsw_index_loader.hpp> -#include <vespa/searchlib/util/bufferwriter.h> +#include <vespa/searchlib/tensor/hnsw_index_traits.h> +#include <vespa/searchlib/tensor/hnsw_nodeid_mapping.h> +#include <vespa/searchlib/test/vector_buffer_reader.h> +#include <vespa/searchlib/test/vector_buffer_writer.h> #include <vespa/searchlib/util/fileutil.h> #include <vespa/vespalib/gtest/gtest.h> #include <vector> @@ -14,39 +18,8 @@ LOG_SETUP("hnsw_save_load_test"); using namespace search::tensor; using search::BufferWriter; using search::fileutil::LoadedBuffer; - -class VectorBufferWriter : public BufferWriter { -private: - char tmp[1024]; -public: - std::vector<char> output; - VectorBufferWriter() { - setup(tmp, 1024); - } - ~VectorBufferWriter() {} - void flush() override { - for (size_t i = 0; i < usedLen(); ++i) { - output.push_back(tmp[i]); - } - rewind(); - } -}; - -class VectorBufferReader { -private: - const std::vector<char>& _data; - size_t _pos; - -public: - VectorBufferReader(const std::vector<char>& data) : _data(data), _pos(0) {} - uint32_t readHostOrder() { - uint32_t result = 0; - assert(_pos + sizeof(uint32_t) <= _data.size()); - std::memcpy(&result, _data.data() + _pos, sizeof(uint32_t)); - _pos += sizeof(uint32_t); - return result; - } -}; +using search::test::VectorBufferReader; +using search::test::VectorBufferWriter; using V = std::vector<uint32_t>; @@ -62,7 +35,14 @@ uint32_t fake_docid<HnswIndexType::SINGLE>(uint32_t nodeid) template <> uint32_t fake_docid<HnswIndexType::MULTI>(uint32_t nodeid) { - return nodeid + 100; + switch (nodeid) { + case 5: + return 104; + case 6: + return 104; + default: + return nodeid + 100; + } } template <HnswIndexType type> @@ -77,7 +57,14 @@ uint32_t fake_subspace<HnswIndexType::SINGLE>(uint32_t) template <> uint32_t fake_subspace<HnswIndexType::MULTI>(uint32_t nodeid) { - return nodeid + 10; + switch (nodeid) { + case 5: + return 2; + case 6: + return 1; + default: + return 0; + } } template <typename NodeType> @@ -99,7 +86,7 @@ template <HnswIndexType type> void populate(HnswGraph<type> &graph) { // no 0 graph.make_node(1, fake_docid<type>(1), fake_subspace<type>(1), 1); - auto er = graph.make_node(2, 102, 12, 2); + auto er = graph.make_node(2, fake_docid<type>(2), fake_subspace<type>(2), 2); // no 3 graph.make_node(4, fake_docid<type>(4), fake_subspace<type>(4), 2); graph.make_node(5, fake_docid<type>(5), fake_subspace<type>(5), 0); @@ -167,7 +154,8 @@ public: return vector_writer.output; } void load_copy(std::vector<char> data) { - HnswIndexLoader<VectorBufferReader, GraphType::index_type> loader(copy, std::make_unique<VectorBufferReader>(data)); + typename HnswIndexTraits<GraphType::index_type>::IdMapping id_mapping; + HnswIndexLoader<VectorBufferReader, GraphType::index_type> loader(copy, id_mapping, std::make_unique<VectorBufferReader>(data)); while (loader.load_next()) {} } |