aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@online.no>2022-11-21 16:24:26 +0100
committerTor Egge <Tor.Egge@online.no>2022-11-21 16:24:26 +0100
commit67689d16d23ecc4b1a2de76ca08cc172ccea7a0f (patch)
tree6aa0a28a78127f96b970884b7030facfcaecebae /searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
parent88a4c159d2fa483e6b1cbcfc7bc56667e3427828 (diff)
Update mapping from docid to nodeids when loading hnsw index.
Diffstat (limited to 'searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp')
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp86
1 files changed, 86 insertions, 0 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 5be4ae9d28f..b86913caa16 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -5,10 +5,13 @@
#include <vespa/searchlib/tensor/distance_functions.h>
#include <vespa/searchlib/tensor/doc_vector_access.h>
#include <vespa/searchlib/tensor/hnsw_index.h>
+#include <vespa/searchlib/tensor/hnsw_index_loader.hpp>
+#include <vespa/searchlib/tensor/hnsw_index_saver.h>
#include <vespa/searchlib/tensor/random_level_generator.h>
#include <vespa/searchlib/tensor/inv_log_level_generator.h>
#include <vespa/searchlib/tensor/subspace_type.h>
#include <vespa/searchlib/tensor/vector_bundle.h>
+#include <vespa/searchlib/util/bufferwriter.h>
#include <vespa/searchlib/queryeval/global_filter.h>
#include <vespa/vespalib/datastore/compaction_spec.h>
#include <vespa/vespalib/datastore/compaction_strategy.h>
@@ -27,12 +30,46 @@ using namespace search::tensor;
using namespace vespalib::slime;
using vespalib::Slime;
using search::BitVector;
+using search::BufferWriter;
using vespalib::eval::get_cell_type;
using vespalib::eval::ValueType;
using vespalib::datastore::CompactionSpec;
using vespalib::datastore::CompactionStrategy;
using search::queryeval::GlobalFilter;
+class VectorBufferWriter : public BufferWriter {
+private:
+ char tmp[1024];
+public:
+ std::vector<char> output;
+ VectorBufferWriter() {
+ setup(tmp, 1024);
+ }
+ ~VectorBufferWriter() {}
+ void flush() override {
+ for (size_t i = 0; i < usedLen(); ++i) {
+ output.push_back(tmp[i]);
+ }
+ rewind();
+ }
+};
+
+class VectorBufferReader {
+private:
+ const std::vector<char>& _data;
+ size_t _pos;
+
+public:
+ VectorBufferReader(const std::vector<char>& data) : _data(data), _pos(0) {}
+ uint32_t readHostOrder() {
+ uint32_t result = 0;
+ assert(_pos + sizeof(uint32_t) <= _data.size());
+ std::memcpy(&result, _data.data() + _pos, sizeof(uint32_t));
+ _pos += sizeof(uint32_t);
+ return result;
+ }
+};
+
template <typename FloatType>
class MyDocVectorAccess : public DocVectorAccess {
private:
@@ -195,6 +232,44 @@ public:
FloatVectors& get_vectors() { return vectors; }
+ uint32_t get_single_nodeid(uint32_t docid) {
+ auto& id_mapping = index->get_id_mapping();
+ auto nodeids = id_mapping.get_ids(docid);
+ EXPECT_EQ(1, nodeids.size());
+ return nodeids[0];
+ }
+
+ void make_savetest_index()
+ {
+ this->add_document(7);
+ this->add_document(4);
+ }
+
+ void check_savetest_index(const vespalib::string& label) {
+ SCOPED_TRACE(label);
+ auto nodeid_for_doc_7 = get_single_nodeid(7);
+ auto nodeid_for_doc_4 = get_single_nodeid(4);
+ EXPECT_EQ(is_single ? 7 : 1, nodeid_for_doc_7);
+ EXPECT_EQ(is_single ? 4 : 2, nodeid_for_doc_4);
+ this->expect_level_0(nodeid_for_doc_7, { nodeid_for_doc_4 });
+ this->expect_level_0(nodeid_for_doc_4, { nodeid_for_doc_7 });
+ }
+
+ std::vector<char> save_index() const {
+ HnswIndexSaver saver(index->get_graph());
+ VectorBufferWriter vector_writer;
+ saver.save(vector_writer);
+ return vector_writer.output;
+ }
+
+ void load_index(std::vector<char> data) {
+ auto& graph = index->get_graph();
+ HnswIndexLoader<VectorBufferReader, IndexType::index_type> loader(graph, std::make_unique<VectorBufferReader>(data));
+ while (loader.load_next()) {}
+ auto& id_mapping = index->get_id_mapping();
+ id_mapping.on_load(graph.node_refs.make_read_view(graph.node_refs.size()));
+ }
+
static constexpr bool is_single = std::is_same_v<IndexType, HnswIndex<HnswIndexType::SINGLE>>;
};
@@ -687,6 +762,17 @@ TYPED_TEST(HnswIndexTest, hnsw_graph_is_compacted)
EXPECT_LT(mem_3.usedBytes(), mem_2.usedBytes());
}
+TYPED_TEST(HnswIndexTest, hnsw_graph_can_be_saved_and_loaded)
+{
+ this->init(false);
+ this->make_savetest_index();
+ this->check_savetest_index("before save");
+ auto data = this->save_index();
+ this->init(false);
+ this->load_index(data);
+ this->check_savetest_index("after load");
+ }
+
TEST(LevelGeneratorTest, gives_various_levels)
{
InvLogLevelGenerator generator(4);