aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp')
-rw-r--r--searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp64
1 files changed, 26 insertions, 38 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp b/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp
index e2a96ec059c..bf4abdd7cf8 100644
--- a/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp
@@ -1,9 +1,13 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/searchlib/tensor/hnsw_graph.h>
+#include <vespa/searchlib/tensor/hnsw_identity_mapping.h>
#include <vespa/searchlib/tensor/hnsw_index_saver.h>
#include <vespa/searchlib/tensor/hnsw_index_loader.hpp>
-#include <vespa/searchlib/util/bufferwriter.h>
+#include <vespa/searchlib/tensor/hnsw_index_traits.h>
+#include <vespa/searchlib/tensor/hnsw_nodeid_mapping.h>
+#include <vespa/searchlib/test/vector_buffer_reader.h>
+#include <vespa/searchlib/test/vector_buffer_writer.h>
#include <vespa/searchlib/util/fileutil.h>
#include <vespa/vespalib/gtest/gtest.h>
#include <vector>
@@ -14,39 +18,8 @@ LOG_SETUP("hnsw_save_load_test");
using namespace search::tensor;
using search::BufferWriter;
using search::fileutil::LoadedBuffer;
-
-class VectorBufferWriter : public BufferWriter {
-private:
- char tmp[1024];
-public:
- std::vector<char> output;
- VectorBufferWriter() {
- setup(tmp, 1024);
- }
- ~VectorBufferWriter() {}
- void flush() override {
- for (size_t i = 0; i < usedLen(); ++i) {
- output.push_back(tmp[i]);
- }
- rewind();
- }
-};
-
-class VectorBufferReader {
-private:
- const std::vector<char>& _data;
- size_t _pos;
-
-public:
- VectorBufferReader(const std::vector<char>& data) : _data(data), _pos(0) {}
- uint32_t readHostOrder() {
- uint32_t result = 0;
- assert(_pos + sizeof(uint32_t) <= _data.size());
- std::memcpy(&result, _data.data() + _pos, sizeof(uint32_t));
- _pos += sizeof(uint32_t);
- return result;
- }
-};
+using search::test::VectorBufferReader;
+using search::test::VectorBufferWriter;
using V = std::vector<uint32_t>;
@@ -62,7 +35,14 @@ uint32_t fake_docid<HnswIndexType::SINGLE>(uint32_t nodeid)
template <>
uint32_t fake_docid<HnswIndexType::MULTI>(uint32_t nodeid)
{
- return nodeid + 100;
+ switch (nodeid) {
+ case 5:
+ return 104;
+ case 6:
+ return 104;
+ default:
+ return nodeid + 100;
+ }
}
template <HnswIndexType type>
@@ -77,7 +57,14 @@ uint32_t fake_subspace<HnswIndexType::SINGLE>(uint32_t)
template <>
uint32_t fake_subspace<HnswIndexType::MULTI>(uint32_t nodeid)
{
- return nodeid + 10;
+ switch (nodeid) {
+ case 5:
+ return 2;
+ case 6:
+ return 1;
+ default:
+ return 0;
+ }
}
template <typename NodeType>
@@ -99,7 +86,7 @@ template <HnswIndexType type>
void populate(HnswGraph<type> &graph) {
// no 0
graph.make_node(1, fake_docid<type>(1), fake_subspace<type>(1), 1);
- auto er = graph.make_node(2, 102, 12, 2);
+ auto er = graph.make_node(2, fake_docid<type>(2), fake_subspace<type>(2), 2);
// no 3
graph.make_node(4, fake_docid<type>(4), fake_subspace<type>(4), 2);
graph.make_node(5, fake_docid<type>(5), fake_subspace<type>(5), 0);
@@ -167,7 +154,8 @@ public:
return vector_writer.output;
}
void load_copy(std::vector<char> data) {
- HnswIndexLoader<VectorBufferReader, GraphType::index_type> loader(copy, std::make_unique<VectorBufferReader>(data));
+ typename HnswIndexTraits<GraphType::index_type>::IdMapping id_mapping;
+ HnswIndexLoader<VectorBufferReader, GraphType::index_type> loader(copy, id_mapping, std::make_unique<VectorBufferReader>(data));
while (loader.load_next()) {}
}