diff options
Diffstat (limited to 'searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp')
-rw-r--r-- | searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp | 71 |
1 files changed, 56 insertions, 15 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index c698a1d612b..089a2a8476e 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -10,6 +10,7 @@ #include <vespa/searchlib/queryeval/nearest_neighbor_blueprint.h> #include <vespa/searchlib/tensor/default_nearest_neighbor_index_factory.h> #include <vespa/searchlib/tensor/dense_tensor_attribute.h> +#include <vespa/searchlib/tensor/direct_tensor_attribute.h> #include <vespa/searchlib/tensor/doc_vector_access.h> #include <vespa/searchlib/tensor/generic_tensor_attribute.h> #include <vespa/searchlib/tensor/hnsw_index.h> @@ -37,6 +38,7 @@ using search::queryeval::GlobalFilter; using search::queryeval::NearestNeighborBlueprint; using search::tensor::DefaultNearestNeighborIndexFactory; using search::tensor::DenseTensorAttribute; +using search::tensor::DirectTensorAttribute; using search::tensor::DocVectorAccess; using search::tensor::GenericTensorAttribute; using search::tensor::HnswIndex; @@ -256,6 +258,40 @@ class MockNearestNeighborIndexFactory : public NearestNeighborIndexFactory { const vespalib::string test_dir = "test_data/"; const vespalib::string attr_name = test_dir + "my_attr"; +struct FixtureTraits { + bool use_dense_tensor_attribute = false; + bool use_direct_tensor_attribute = false; + bool enable_hnsw_index = false; + bool use_mock_index = false; + + FixtureTraits dense() && { + use_dense_tensor_attribute = true; + enable_hnsw_index = false; + return *this; + } + + FixtureTraits hnsw() && { + use_dense_tensor_attribute = true; + enable_hnsw_index = true; + use_mock_index = false; + return *this; + } + + FixtureTraits mock_hnsw() && { + use_dense_tensor_attribute = true; + enable_hnsw_index = true; + use_mock_index = true; + return *this; + } + + FixtureTraits direct() && { + use_dense_tensor_attribute = false; + use_direct_tensor_attribute = true; + return *this; + } + +}; + struct Fixture { using BasicType = search::attribute::BasicType; using CollectionType = search::attribute::CollectionType; @@ -270,24 +306,21 @@ struct Fixture { std::shared_ptr<TensorAttribute> _tensorAttr; std::shared_ptr<AttributeVector> _attr; bool _denseTensors; - bool _useDenseTensorAttribute; + FixtureTraits _traits; Fixture(const vespalib::string &typeSpec, - bool useDenseTensorAttribute = false, - bool enable_hnsw_index = false, - bool use_mock_index = false) + FixtureTraits traits = FixtureTraits()) : _dir_handler(test_dir), _cfg(BasicType::TENSOR, CollectionType::SINGLE), _name(attr_name), _typeSpec(typeSpec), - _use_mock_index(use_mock_index), _index_factory(), _tensorAttr(), _attr(), _denseTensors(false), - _useDenseTensorAttribute(useDenseTensorAttribute) + _traits(traits) { - if (enable_hnsw_index) { + if (traits.enable_hnsw_index) { _cfg.set_distance_metric(DistanceMetric::Euclidean); _cfg.set_hnsw_index_params(HnswIndexParams(4, 20, DistanceMetric::Euclidean)); } @@ -301,7 +334,7 @@ struct Fixture { if (_cfg.tensorType().is_dense()) { _denseTensors = true; } - if (_use_mock_index) { + if (_traits.use_mock_index) { _index_factory = std::make_unique<MockNearestNeighborIndexFactory>(); } else { _index_factory = std::make_unique<DefaultNearestNeighborIndexFactory>(); @@ -322,9 +355,11 @@ struct Fixture { } std::shared_ptr<TensorAttribute> makeAttr() { - if (_useDenseTensorAttribute) { + if (_traits.use_dense_tensor_attribute) { assert(_denseTensors); return std::make_shared<DenseTensorAttribute>(_name, _cfg, *_index_factory); + } else if (_traits.use_direct_tensor_attribute) { + return std::make_shared<DirectTensorAttribute>(_name, _cfg); } else { return std::make_shared<GenericTensorAttribute>(_name, _cfg); } @@ -543,7 +578,7 @@ Fixture::testSaveLoad() void Fixture::testCompaction() { - if (_useDenseTensorAttribute && _denseTensors) { + if (_traits.use_dense_tensor_attribute && _denseTensors) { LOG(info, "Skipping compaction test for tensor '%s' which is using free-lists", _cfg.tensorType().to_spec().c_str()); return; } @@ -607,7 +642,7 @@ Fixture::testTensorTypeFileHeaderTag() auto header = get_file_header(); EXPECT_TRUE(header.hasTag("tensortype")); EXPECT_EQUAL(_typeSpec, header.getTag("tensortype").asString()); - if (_useDenseTensorAttribute) { + if (_traits.use_dense_tensor_attribute) { EXPECT_EQUAL(1u, header.getTag("version").asInteger()); } else { EXPECT_EQUAL(0u, header.getTag("version").asInteger()); @@ -644,6 +679,11 @@ TEST("Test sparse tensors with generic tensor attribute") testAll([]() { return std::make_shared<Fixture>(sparseSpec); }); } +TEST("Test sparse tensors with direct tensor attribute") +{ + testAll([]() { return std::make_shared<Fixture>(sparseSpec, FixtureTraits().direct()); }); +} + TEST("Test dense tensors with generic tensor attribute") { testAll([]() { return std::make_shared<Fixture>(denseSpec); }); @@ -651,11 +691,11 @@ TEST("Test dense tensors with generic tensor attribute") TEST("Test dense tensors with dense tensor attribute") { - testAll([]() { return std::make_shared<Fixture>(denseSpec, true); }); + testAll([]() { return std::make_shared<Fixture>(denseSpec, FixtureTraits().dense()); }); } TEST_F("Hnsw index is NOT instantiated in dense tensor attribute by default", - Fixture(vec_2d_spec, true, false)) + Fixture(vec_2d_spec, FixtureTraits().dense())) { const auto& tensor = f.as_dense_tensor(); EXPECT_TRUE(tensor.nearest_neighbor_index() == nullptr); @@ -663,7 +703,7 @@ TEST_F("Hnsw index is NOT instantiated in dense tensor attribute by default", class DenseTensorAttributeHnswIndex : public Fixture { public: - DenseTensorAttributeHnswIndex() : Fixture(vec_2d_spec, true, true, false) {} + DenseTensorAttributeHnswIndex() : Fixture(vec_2d_spec, FixtureTraits().hnsw()) {} }; TEST_F("Hnsw index is instantiated in dense tensor attribute when specified in config", DenseTensorAttributeHnswIndex) @@ -704,9 +744,10 @@ TEST_F("Hnsw index is integrated in dense tensor attribute and can be saved and expect_level_0(1, index_b.get_node(2)); } + class DenseTensorAttributeMockIndex : public Fixture { public: - DenseTensorAttributeMockIndex() : Fixture(vec_2d_spec, true, true, true) {} + DenseTensorAttributeMockIndex() : Fixture(vec_2d_spec, FixtureTraits().mock_hnsw()) {} }; TEST_F("setTensor() updates nearest neighbor index", DenseTensorAttributeMockIndex) |