aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp')
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp71
1 files changed, 56 insertions, 15 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index c698a1d612b..089a2a8476e 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -10,6 +10,7 @@
#include <vespa/searchlib/queryeval/nearest_neighbor_blueprint.h>
#include <vespa/searchlib/tensor/default_nearest_neighbor_index_factory.h>
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
+#include <vespa/searchlib/tensor/direct_tensor_attribute.h>
#include <vespa/searchlib/tensor/doc_vector_access.h>
#include <vespa/searchlib/tensor/generic_tensor_attribute.h>
#include <vespa/searchlib/tensor/hnsw_index.h>
@@ -37,6 +38,7 @@ using search::queryeval::GlobalFilter;
using search::queryeval::NearestNeighborBlueprint;
using search::tensor::DefaultNearestNeighborIndexFactory;
using search::tensor::DenseTensorAttribute;
+using search::tensor::DirectTensorAttribute;
using search::tensor::DocVectorAccess;
using search::tensor::GenericTensorAttribute;
using search::tensor::HnswIndex;
@@ -256,6 +258,40 @@ class MockNearestNeighborIndexFactory : public NearestNeighborIndexFactory {
const vespalib::string test_dir = "test_data/";
const vespalib::string attr_name = test_dir + "my_attr";
+struct FixtureTraits {
+ bool use_dense_tensor_attribute = false;
+ bool use_direct_tensor_attribute = false;
+ bool enable_hnsw_index = false;
+ bool use_mock_index = false;
+
+ FixtureTraits dense() && {
+ use_dense_tensor_attribute = true;
+ enable_hnsw_index = false;
+ return *this;
+ }
+
+ FixtureTraits hnsw() && {
+ use_dense_tensor_attribute = true;
+ enable_hnsw_index = true;
+ use_mock_index = false;
+ return *this;
+ }
+
+ FixtureTraits mock_hnsw() && {
+ use_dense_tensor_attribute = true;
+ enable_hnsw_index = true;
+ use_mock_index = true;
+ return *this;
+ }
+
+ FixtureTraits direct() && {
+ use_dense_tensor_attribute = false;
+ use_direct_tensor_attribute = true;
+ return *this;
+ }
+
+};
+
struct Fixture {
using BasicType = search::attribute::BasicType;
using CollectionType = search::attribute::CollectionType;
@@ -270,24 +306,21 @@ struct Fixture {
std::shared_ptr<TensorAttribute> _tensorAttr;
std::shared_ptr<AttributeVector> _attr;
bool _denseTensors;
- bool _useDenseTensorAttribute;
+ FixtureTraits _traits;
Fixture(const vespalib::string &typeSpec,
- bool useDenseTensorAttribute = false,
- bool enable_hnsw_index = false,
- bool use_mock_index = false)
+ FixtureTraits traits = FixtureTraits())
: _dir_handler(test_dir),
_cfg(BasicType::TENSOR, CollectionType::SINGLE),
_name(attr_name),
_typeSpec(typeSpec),
- _use_mock_index(use_mock_index),
_index_factory(),
_tensorAttr(),
_attr(),
_denseTensors(false),
- _useDenseTensorAttribute(useDenseTensorAttribute)
+ _traits(traits)
{
- if (enable_hnsw_index) {
+ if (traits.enable_hnsw_index) {
_cfg.set_distance_metric(DistanceMetric::Euclidean);
_cfg.set_hnsw_index_params(HnswIndexParams(4, 20, DistanceMetric::Euclidean));
}
@@ -301,7 +334,7 @@ struct Fixture {
if (_cfg.tensorType().is_dense()) {
_denseTensors = true;
}
- if (_use_mock_index) {
+ if (_traits.use_mock_index) {
_index_factory = std::make_unique<MockNearestNeighborIndexFactory>();
} else {
_index_factory = std::make_unique<DefaultNearestNeighborIndexFactory>();
@@ -322,9 +355,11 @@ struct Fixture {
}
std::shared_ptr<TensorAttribute> makeAttr() {
- if (_useDenseTensorAttribute) {
+ if (_traits.use_dense_tensor_attribute) {
assert(_denseTensors);
return std::make_shared<DenseTensorAttribute>(_name, _cfg, *_index_factory);
+ } else if (_traits.use_direct_tensor_attribute) {
+ return std::make_shared<DirectTensorAttribute>(_name, _cfg);
} else {
return std::make_shared<GenericTensorAttribute>(_name, _cfg);
}
@@ -543,7 +578,7 @@ Fixture::testSaveLoad()
void
Fixture::testCompaction()
{
- if (_useDenseTensorAttribute && _denseTensors) {
+ if (_traits.use_dense_tensor_attribute && _denseTensors) {
LOG(info, "Skipping compaction test for tensor '%s' which is using free-lists", _cfg.tensorType().to_spec().c_str());
return;
}
@@ -607,7 +642,7 @@ Fixture::testTensorTypeFileHeaderTag()
auto header = get_file_header();
EXPECT_TRUE(header.hasTag("tensortype"));
EXPECT_EQUAL(_typeSpec, header.getTag("tensortype").asString());
- if (_useDenseTensorAttribute) {
+ if (_traits.use_dense_tensor_attribute) {
EXPECT_EQUAL(1u, header.getTag("version").asInteger());
} else {
EXPECT_EQUAL(0u, header.getTag("version").asInteger());
@@ -644,6 +679,11 @@ TEST("Test sparse tensors with generic tensor attribute")
testAll([]() { return std::make_shared<Fixture>(sparseSpec); });
}
+TEST("Test sparse tensors with direct tensor attribute")
+{
+ testAll([]() { return std::make_shared<Fixture>(sparseSpec, FixtureTraits().direct()); });
+}
+
TEST("Test dense tensors with generic tensor attribute")
{
testAll([]() { return std::make_shared<Fixture>(denseSpec); });
@@ -651,11 +691,11 @@ TEST("Test dense tensors with generic tensor attribute")
TEST("Test dense tensors with dense tensor attribute")
{
- testAll([]() { return std::make_shared<Fixture>(denseSpec, true); });
+ testAll([]() { return std::make_shared<Fixture>(denseSpec, FixtureTraits().dense()); });
}
TEST_F("Hnsw index is NOT instantiated in dense tensor attribute by default",
- Fixture(vec_2d_spec, true, false))
+ Fixture(vec_2d_spec, FixtureTraits().dense()))
{
const auto& tensor = f.as_dense_tensor();
EXPECT_TRUE(tensor.nearest_neighbor_index() == nullptr);
@@ -663,7 +703,7 @@ TEST_F("Hnsw index is NOT instantiated in dense tensor attribute by default",
class DenseTensorAttributeHnswIndex : public Fixture {
public:
- DenseTensorAttributeHnswIndex() : Fixture(vec_2d_spec, true, true, false) {}
+ DenseTensorAttributeHnswIndex() : Fixture(vec_2d_spec, FixtureTraits().hnsw()) {}
};
TEST_F("Hnsw index is instantiated in dense tensor attribute when specified in config", DenseTensorAttributeHnswIndex)
@@ -704,9 +744,10 @@ TEST_F("Hnsw index is integrated in dense tensor attribute and can be saved and
expect_level_0(1, index_b.get_node(2));
}
+
class DenseTensorAttributeMockIndex : public Fixture {
public:
- DenseTensorAttributeMockIndex() : Fixture(vec_2d_spec, true, true, true) {}
+ DenseTensorAttributeMockIndex() : Fixture(vec_2d_spec, FixtureTraits().mock_hnsw()) {}
};
TEST_F("setTensor() updates nearest neighbor index", DenseTensorAttributeMockIndex)