about summary refs log tree commit diff stats
path: root/searchlib/src/tests
diff options
context:
space:
mode:
author Tor Egge <Tor.Egge@online.no> 2022-11-23 11:22:10 +0100
committer Tor Egge <Tor.Egge@online.no> 2022-11-23 11:22:10 +0100
commit ea7389254797d0b45940439ad3d7c7f3370b3af1 (patch)
tree de4ac043a312a679a81ffa2d451bc1599c0678db /searchlib/src/tests
parent e03f1e82952dbdb801e737de41b285c0fa74c3f9 (diff)
Setup hnsw index for mixed tensor types.
Diffstat (limited to 'searchlib/src/tests')
-rw-r--r-- searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp 125
1 files changed, 108 insertions, 17 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index 8a3be423457..6fa1bdcf072 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -19,6 +19,7 @@
#include <vespa/searchlib/util/fileutil.h>
#include <vespa/searchcommon/attribute/config.h>
#include <vespa/vespalib/data/fileheader.h>
+#include <vespa/vespalib/stllike/asciistream.h>
#include <vespa/vespalib/test/insertion_operators.h>
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/vespalib/util/mmap_file_allocator_factory.h>
@@ -72,6 +73,7 @@ using generation_t = vespalib::GenerationHandler::generation_t;
vespalib::string sparseSpec("tensor(x{},y{})");
vespalib::string denseSpec("tensor(x[2],y[3])");
vespalib::string vec_2d_spec("tensor(x[2])");
+vespalib::string vec_mixed_2d_spec("tensor(a{},x[2])");
Value::UP createTensor(const TensorSpec &spec) {
return SimpleValue::from_spec(spec);
@@ -83,6 +85,31 @@ vec_2d(double x0, double x1)
return TensorSpec(vec_2d_spec).add({{"x", 0}}, x0).add({{"x", 1}}, x1);
}
+TensorSpec
+vec_mixed_2d(std::vector<std::vector<double>> val) // Builds a TensorSpec of type vec_mixed_2d_spec with one cell per element of val.
+{
+ TensorSpec spec(vec_mixed_2d_spec);
+ for (uint32_t a = 0; a < val.size(); ++a) {
+ vespalib::asciistream a_stream;
+ a_stream << a; // The outer index, formatted as text, is used as the label for dimension "a".
+ vespalib::string a_as_string = a_stream.str();
+ for (uint32_t x = 0; x < val[a].size(); ++x) {
+ spec.add({{"a", a_as_string.c_str()},{"x", x}}, val[a][x]); // The inner index is the address in dimension "x".
+ }
+ }
+ return spec;
+}
+
+TensorSpec
+typed_vec_2d(HnswIndexType type, double x0, double x1) // Builds a spec holding the single vector (x0, x1), shaped according to the index type.
+{
+ if (type == HnswIndexType::SINGLE) {
+ return vec_2d(x0, x1); // SINGLE: spec produced by vec_2d().
+ } else {
+ return vec_mixed_2d({{x0, x1}}); // Any other type: spec produced by vec_mixed_2d() with one outer entry.
+ }
+}
+
class MockIndexSaver : public NearestNeighborIndexSaver {
private:
int _index_value;
@@ -274,7 +301,6 @@ public:
return std::vector<Neighbor>();
}
-
const search::tensor::DistanceFunction *distance_function() const override {
static search::tensor::SquaredEuclideanDistance my_dist_fun(vespalib::eval::CellType::DOUBLE);
return &my_dist_fun;
@@ -285,10 +311,12 @@ class MockNearestNeighborIndexFactory : public NearestNeighborIndexFactory {
std::unique_ptr<NearestNeighborIndex> make(const DocVectorAccess& vectors, // Creates a MockNearestNeighborIndex; only `vectors` is forwarded to it.
size_t vector_size,
+ bool multi_vector_index,
CellType cell_type,
const search::attribute::HnswIndexParams& params) const override {
(void) vector_size; // Unused by the mock.
(void) params; // Unused by the mock.
+ (void) multi_vector_index; // Unused by the mock.
assert(cell_type == CellType::DOUBLE); // The mock only accepts double cells.
return std::make_unique<MockNearestNeighborIndex>(vectors);
}
@@ -322,6 +350,13 @@ struct FixtureTraits {
return *this;
}
+ FixtureTraits mixed_hnsw() && { // Rvalue-only: configures the traits and returns a copy of them.
+ use_dense_tensor_attribute = false; // Do not use the dense tensor attribute.
+ enable_hnsw_index = true;
+ use_mock_index = false; // Use a real index rather than the mock.
+ return *this;
+ }
+
FixtureTraits mock_hnsw() && {
use_dense_tensor_attribute = true;
enable_hnsw_index = true;
@@ -406,8 +441,8 @@ struct Fixture {
template <typename IndexType>
IndexType& get_nearest_neighbor_index() { // Returns the attribute's nearest neighbor index downcast to IndexType.
- assert(as_dense_tensor().nearest_neighbor_index() != nullptr);
- auto index = dynamic_cast<const IndexType*>(as_dense_tensor().nearest_neighbor_index());
+ assert(_tensorAttr->nearest_neighbor_index() != nullptr); // An index must exist; it is now queried via _tensorAttr instead of as_dense_tensor().
+ auto index = dynamic_cast<const IndexType*>(_tensorAttr->nearest_neighbor_index());
assert(index != nullptr); // The index must have dynamic type IndexType.
return *const_cast<IndexType*>(index); // Strips the const added by the cast above so callers get a mutable reference.
}
@@ -416,6 +451,11 @@ struct Fixture {
return get_nearest_neighbor_index<HnswIndex<HnswIndexType::SINGLE>>();
}
+ template <HnswIndexType type>
+ HnswIndex<type>& hnsw_typed_index() { // Returns the nearest neighbor index as HnswIndex<type>.
+ return get_nearest_neighbor_index<HnswIndex<type>>(); // Asserts if there is no index or it is not an HnswIndex<type>.
+ }
+
MockNearestNeighborIndex& mock_index() { // Returns the nearest neighbor index as MockNearestNeighborIndex.
return get_nearest_neighbor_index<MockNearestNeighborIndex>(); // Asserts if there is no index or it is not the mock.
}
@@ -836,15 +876,24 @@ TEST_F("Hnsw index is NOT instantiated in dense tensor attribute by default",
EXPECT_TRUE(tensor.nearest_neighbor_index() == nullptr);
}
-class DenseTensorAttributeHnswIndex : public Fixture {
+
+template <HnswIndexType type>
+class TensorAttributeHnswIndex : public Fixture // Fixture with test bodies parameterized on the hnsw index type.
+{
public:
- DenseTensorAttributeHnswIndex() : Fixture(vec_2d_spec, FixtureTraits().hnsw()) {}
+ TensorAttributeHnswIndex(const vespalib::string &type_spec, FixtureTraits traits) // Tensor type spec and traits are supplied by the concrete fixture.
+ : Fixture(type_spec, traits)
+ {
+ }
+ void test_setup(); // Checks the configuration of the instantiated index.
+ void test_save_load(bool multi_node); // Checks that the index content is the same after save and load.
};
-TEST_F("Hnsw index is instantiated in dense tensor attribute when specified in config", DenseTensorAttributeHnswIndex)
+template <HnswIndexType type>
+void
+TensorAttributeHnswIndex<type>::test_setup()
{
- auto& index = f.hnsw_index();
-
+ auto& index = hnsw_typed_index<type>();
const auto& cfg = index.config();
EXPECT_EQUAL(8u, cfg.max_links_at_level_0());
EXPECT_EQUAL(4u, cfg.max_links_on_inserts());
@@ -853,32 +902,74 @@ TEST_F("Hnsw index is instantiated in dense tensor attribute when specified in c
}
void
-expect_level_0(uint32_t exp_docid, const HnswTestNode& node)
+expect_level_0(uint32_t exp_nodeid, const HnswTestNode& node) // Checks that node has exactly one link at level 0 and that it points to exp_nodeid.
{
ASSERT_GREATER_EQUAL(node.size(), 1u); // The node must have at least one level.
ASSERT_EQUAL(1u, node.level(0).size()); // Level 0 must hold exactly one link.
- EXPECT_EQUAL(exp_docid, node.level(0)[0]);
+ EXPECT_EQUAL(exp_nodeid, node.level(0)[0]);
}
-TEST_F("Hnsw index is integrated in dense tensor attribute and can be saved and loaded", DenseTensorAttributeHnswIndex)
+template <HnswIndexType type>
+void
+TensorAttributeHnswIndex<type>::test_save_load(bool multi_node) // Populates the index, saves and reloads the attribute, then checks the level 0 links are unchanged.
{
// Set two points that will be linked together in level 0 of the hnsw graph.
- f.set_tensor(1, vec_2d(3, 5));
- f.set_tensor(2, vec_2d(7, 9));
+ if (multi_node) {
+ set_tensor(1, vec_mixed_2d({{3, 5}, {7, 9}})); // Both vectors are stored in the single document 1.
+ } else {
+ set_tensor(1, typed_vec_2d(type, 3, 5)); // One vector in document 1 ...
+ set_tensor(2, typed_vec_2d(type, 7, 9)); // ... and one in document 2.
+ }
- auto &index_a = f.hnsw_index();
+ auto old_attr = _attr; // NOTE(review): not read below; presumably keeps the pre-load attribute (and index_a) alive across load() - confirm.
+ auto &index_a = hnsw_typed_index<type>();
expect_level_0(2, index_a.get_node(1)); // In both modes, nodes 1 and 2 are expected to link to each other.
expect_level_0(1, index_a.get_node(2));
- f.save();
+ save();
EXPECT_TRUE(std::filesystem::exists(std::filesystem::path(attr_name + ".nnidx"))); // save() is expected to have written the ".nnidx" file.
- f.load();
- auto &index_b = f.hnsw_index();
+ load();
+ auto &index_b = hnsw_typed_index<type>();
EXPECT_NOT_EQUAL(&index_a, &index_b); // The index after load must be a different object than the one before.
expect_level_0(2, index_b.get_node(1)); // Same links as before the save.
expect_level_0(1, index_b.get_node(2));
}
+class DenseTensorAttributeHnswIndex : public TensorAttributeHnswIndex<HnswIndexType::SINGLE> { // Fixture for the SINGLE index type.
+public:
+ DenseTensorAttributeHnswIndex() : TensorAttributeHnswIndex<HnswIndexType::SINGLE>(vec_2d_spec, FixtureTraits().hnsw()) {} // Uses tensor type vec_2d_spec with the hnsw() traits.
+};
+
+class MixedTensorAttributeHnswIndex : public TensorAttributeHnswIndex<HnswIndexType::MULTI> { // Fixture for the MULTI index type.
+public:
+ MixedTensorAttributeHnswIndex() : TensorAttributeHnswIndex<HnswIndexType::MULTI>(vec_mixed_2d_spec, FixtureTraits().mixed_hnsw()) {} // Uses tensor type vec_mixed_2d_spec with the mixed_hnsw() traits.
+};
+
+TEST_F("Hnsw index is instantiated in dense tensor attribute when specified in config", DenseTensorAttributeHnswIndex)
+{
+ f.test_setup(); // Shared config check from TensorAttributeHnswIndex, here for the SINGLE index type.
+}
+
+TEST_F("Hnsw index is integrated in dense tensor attribute and can be saved and loaded", DenseTensorAttributeHnswIndex)
+{
+ f.test_save_load(false); // false: two documents with one vector each.
+}
+
+TEST_F("Hnsw index is instantiated in mixed tensor attribute when specified in config", MixedTensorAttributeHnswIndex)
+{
+ f.test_setup(); // Shared config check from TensorAttributeHnswIndex, here for the MULTI index type.
+}
+
+TEST_F("Hnsw index is integrated in mixed tensor attribute and can be saved and loaded", MixedTensorAttributeHnswIndex)
+{
+ f.test_save_load(false); // false: two documents with one vector each.
+}
+
+TEST_F("Hnsw index is integrated in mixed tensor attribute and can be saved and loaded with multiple points per document", MixedTensorAttributeHnswIndex)
+{
+ f.test_save_load(true); // true: a single document carrying two vectors.
+}
+
TEST_F("Populates address space usage", DenseTensorAttributeHnswIndex)
{
search::AddressSpaceUsage usage = f._attr->getAddressSpaceUsage();