summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2020-02-20 15:15:22 +0100
committerGitHub <noreply@github.com>2020-02-20 15:15:22 +0100
commit8b9ddc1fea064f2851f540b9fdeff94d12c8ffa4 (patch)
tree94f8270a8a21d066ebd0b67e48272ab41ea0e60a
parentf9661fed45e23cd170c77c668e30852f2f42d5b4 (diff)
parente59c8adf957ee3147ce6e8d90757c5882c494fb2 (diff)
Merge pull request #12287 from vespa-engine/geirst/integrate-nearest-neighbor-index-in-dense-tensor-attribute
Integrate nearest neighbor index in dense tensor attribute
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java18
-rw-r--r--searchcommon/src/vespa/searchcommon/attribute/config.cpp9
-rw-r--r--searchcommon/src/vespa/searchcommon/attribute/config.h18
-rw-r--r--searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h32
-rw-r--r--searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp16
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp332
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp10
-rw-r--r--searchlib/src/vespa/searchlib/attribute/configconverter.cpp4
-rw-r--r--searchlib/src/vespa/searchlib/tensor/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp51
-rw-r--r--searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.h19
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp43
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h48
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_function.h3
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_functions.h1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp12
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.h15
-rw-r--r--searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index_factory.h26
-rw-r--r--searchlib/src/vespa/searchlib/tensor/random_level_generator.h3
19 files changed, 530 insertions, 131 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java
index 9cd7fb24e42..2790f2ddf6e 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java
@@ -24,16 +24,18 @@ public class TensorFieldProcessor extends Processor {
@Override
public void process(boolean validate, boolean documentsOnly) {
- if ( ! validate) return;
-
for (var field : search.allConcreteFields()) {
if ( field.getDataType() instanceof TensorDataType ) {
- validateIndexingScripsForTensorField(field);
- validateAttributeSettingForTensorField(field);
- processIndexSettingsForTensorField(field);
+ if (validate) {
+ validateIndexingScripsForTensorField(field);
+ validateAttributeSettingForTensorField(field);
+ }
+ processIndexSettingsForTensorField(field, validate);
}
else if (field.getDataType() instanceof CollectionDataType){
- validateDataTypeForCollectionField(field);
+ if (validate) {
+ validateDataTypeForCollectionField(field);
+ }
}
}
}
@@ -68,12 +70,12 @@ public class TensorFieldProcessor extends Processor {
}
}
- private void processIndexSettingsForTensorField(SDField field) {
+ private void processIndexSettingsForTensorField(SDField field, boolean validate) {
if (!field.doesIndexing()) {
return;
}
if (isTensorTypeThatSupportsHnswIndex(field)) {
- if (!field.doesAttributing()) {
+ if (validate && !field.doesAttributing()) {
fail(search, field, "A tensor that has an index must also be an attribute.");
}
var index = field.getIndex(field.getName());
diff --git a/searchcommon/src/vespa/searchcommon/attribute/config.cpp b/searchcommon/src/vespa/searchcommon/attribute/config.cpp
index 53e57fd9c66..b4e05875820 100644
--- a/searchcommon/src/vespa/searchcommon/attribute/config.cpp
+++ b/searchcommon/src/vespa/searchcommon/attribute/config.cpp
@@ -17,7 +17,8 @@ Config::Config() :
_growStrategy(),
_compactionStrategy(),
_predicateParams(),
- _tensorType(vespalib::eval::ValueType::error_type())
+ _tensorType(vespalib::eval::ValueType::error_type()),
+ _hnsw_index_params()
{
}
@@ -34,7 +35,8 @@ Config::Config(BasicType bt, CollectionType ct, bool fastSearch_, bool huge_)
_growStrategy(),
_compactionStrategy(),
_predicateParams(),
- _tensorType(vespalib::eval::ValueType::error_type())
+ _tensorType(vespalib::eval::ValueType::error_type()),
+ _hnsw_index_params()
{
}
@@ -60,7 +62,8 @@ Config::operator==(const Config &b) const
_compactionStrategy == b._compactionStrategy &&
_predicateParams == b._predicateParams &&
(_basicType.type() != BasicType::Type::TENSOR ||
- _tensorType == b._tensorType);
+ _tensorType == b._tensorType) &&
+ _hnsw_index_params == b._hnsw_index_params;
}
}
diff --git a/searchcommon/src/vespa/searchcommon/attribute/config.h b/searchcommon/src/vespa/searchcommon/attribute/config.h
index 2f767061f7a..836fcfed84a 100644
--- a/searchcommon/src/vespa/searchcommon/attribute/config.h
+++ b/searchcommon/src/vespa/searchcommon/attribute/config.h
@@ -4,15 +4,21 @@
#include "basictype.h"
#include "collectiontype.h"
+#include "hnsw_index_params.h"
#include "predicate_params.h"
-#include <vespa/searchcommon/common/growstrategy.h>
#include <vespa/searchcommon/common/compaction_strategy.h>
+#include <vespa/searchcommon/common/growstrategy.h>
#include <vespa/eval/eval/value_type.h>
+#include <optional>
namespace search::attribute {
-class Config
-{
+/**
+ * Configuration for an attribute vector.
+ *
+ * Used to determine which implementation to instantiate.
+ */
+class Config {
public:
Config();
Config(BasicType bt, CollectionType ct = CollectionType::SINGLE,
@@ -29,6 +35,7 @@ public:
bool huge() const { return _huge; }
const PredicateParams &predicateParams() const { return _predicateParams; }
vespalib::eval::ValueType tensorType() const { return _tensorType; }
+ const std::optional<HnswIndexParams>& hnsw_index_params() const { return _hnsw_index_params; }
/**
* Check if attribute posting list can consist of a bitvector in
@@ -60,6 +67,10 @@ public:
_tensorType = tensorType_in;
return *this;
}
+ Config& set_hnsw_index_params(const HnswIndexParams& params) {
+ _hnsw_index_params = params;
+ return *this;
+ }
/**
* Enable attribute posting list to consist of a bitvector in
@@ -107,6 +118,7 @@ private:
CompactionStrategy _compactionStrategy;
PredicateParams _predicateParams;
vespalib::eval::ValueType _tensorType;
+ std::optional<HnswIndexParams> _hnsw_index_params;
};
}
diff --git a/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h b/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h
new file mode 100644
index 00000000000..9e98a8c5fb7
--- /dev/null
+++ b/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h
@@ -0,0 +1,32 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+namespace search::attribute {
+
+/**
+ * Configuration parameters for a hnsw index used together with a 1-dimensional indexed tensor
+ * for approximate nearest neighbor search.
+ */
+class HnswIndexParams {
+private:
+ uint32_t _max_links_per_node;
+ uint32_t _neighbors_to_explore_at_insert;
+
+public:
+ HnswIndexParams(uint32_t max_links_per_node_in,
+ uint32_t neighbors_to_explore_at_insert_in)
+ : _max_links_per_node(max_links_per_node_in),
+ _neighbors_to_explore_at_insert(neighbors_to_explore_at_insert_in)
+ {}
+
+ uint32_t max_links_per_node() const { return _max_links_per_node; }
+ uint32_t neighbors_to_explore_at_insert() const { return _neighbors_to_explore_at_insert; }
+
+ bool operator==(const HnswIndexParams& rhs) const {
+ return _max_links_per_node == rhs._max_links_per_node &&
+ _neighbors_to_explore_at_insert == rhs._neighbors_to_explore_at_insert;
+ }
+};
+
+}
diff --git a/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp b/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp
index 7d09b2aa0b8..850a967ed3d 100644
--- a/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp
+++ b/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp
@@ -278,6 +278,22 @@ AttributeManagerTest::testConfigConvert()
AttributeVector::Config out = ConfigConverter::convert(a);
EXPECT_EQUAL("tensor(x[5])", out.tensorType().to_spec());
}
+ { // hnsw index params (enabled)
+ CACA a;
+ a.index.hnsw.enabled = true;
+ a.index.hnsw.maxlinkspernode = 32;
+ a.index.hnsw.neighborstoexploreatinsert = 300;
+ auto out = ConfigConverter::convert(a);
+ EXPECT_TRUE(out.hnsw_index_params().has_value());
+ EXPECT_EQUAL(32u, out.hnsw_index_params().value().max_links_per_node());
+ EXPECT_EQUAL(300u, out.hnsw_index_params().value().neighbors_to_explore_at_insert());
+ }
+ { // hnsw index params (disabled)
+ CACA a;
+ a.index.hnsw.enabled = false;
+ auto out = ConfigConverter::convert(a);
+ EXPECT_FALSE(out.hnsw_index_params().has_value());
+ }
}
bool gt_attribute(const attribute::IAttributeVector * a, const attribute::IAttributeVector * b) {
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index 7e0fcdc0ccc..644230cb340 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -1,34 +1,48 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/document/base/exceptions.h>
-#include <vespa/searchlib/tensor/tensor_attribute.h>
-#include <vespa/searchlib/tensor/generic_tensor_attribute.h>
-#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
-#include <vespa/searchlib/attribute/attributeguard.h>
-#include <vespa/eval/tensor/tensor.h>
-#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
-#include <vespa/vespalib/io/fileutil.h>
-#include <vespa/vespalib/data/fileheader.h>
+#include <vespa/eval/tensor/dense/dense_tensor.h>
+#include <vespa/eval/tensor/tensor.h>
#include <vespa/fastos/file.h>
+#include <vespa/searchlib/attribute/attributeguard.h>
+#include <vespa/searchlib/tensor/default_nearest_neighbor_index_factory.h>
+#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
+#include <vespa/searchlib/tensor/doc_vector_access.h>
+#include <vespa/searchlib/tensor/generic_tensor_attribute.h>
+#include <vespa/searchlib/tensor/hnsw_index.h>
+#include <vespa/searchlib/tensor/nearest_neighbor_index.h>
+#include <vespa/searchlib/tensor/nearest_neighbor_index_factory.h>
+#include <vespa/searchlib/tensor/tensor_attribute.h>
+#include <vespa/vespalib/data/fileheader.h>
+#include <vespa/vespalib/io/fileutil.h>
+#include <vespa/vespalib/test/insertion_operators.h>
+#include <vespa/vespalib/testkit/test_kit.h>
+
#include <vespa/log/log.h>
LOG_SETUP("tensorattribute_test");
using document::WrongTensorTypeException;
-using search::tensor::TensorAttribute;
-using search::tensor::DenseTensorAttribute;
-using search::tensor::GenericTensorAttribute;
using search::AttributeGuard;
using search::AttributeVector;
-using vespalib::eval::ValueType;
+using search::attribute::HnswIndexParams;
+using search::tensor::DefaultNearestNeighborIndexFactory;
+using search::tensor::DenseTensorAttribute;
+using search::tensor::DocVectorAccess;
+using search::tensor::GenericTensorAttribute;
+using search::tensor::HnswIndex;
+using search::tensor::NearestNeighborIndex;
+using search::tensor::NearestNeighborIndexFactory;
+using search::tensor::TensorAttribute;
using vespalib::eval::TensorSpec;
-using vespalib::tensor::Tensor;
-using vespalib::tensor::DenseTensor;
+using vespalib::eval::ValueType;
using vespalib::tensor::DefaultTensorEngine;
+using vespalib::tensor::DenseTensor;
+using vespalib::tensor::Tensor;
+
+using DoubleVector = std::vector<double>;
-namespace vespalib {
-namespace tensor {
+namespace vespalib::tensor {
static bool operator==(const Tensor &lhs, const Tensor &rhs)
{
@@ -36,10 +50,10 @@ static bool operator==(const Tensor &lhs, const Tensor &rhs)
}
}
-}
vespalib::string sparseSpec("tensor(x{},y{})");
vespalib::string denseSpec("tensor(x[2],y[3])");
+vespalib::string vec_2d_spec("tensor(x[2])");
Tensor::UP createTensor(const TensorSpec &spec) {
auto value = DefaultTensorEngine::ref().from_spec(spec);
@@ -52,6 +66,78 @@ Tensor::UP createTensor(const TensorSpec &spec) {
return Tensor::UP(tensor);
}
+TensorSpec
+vec_2d(double x0, double x1)
+{
+ return TensorSpec(vec_2d_spec).add({{"x", 0}}, x0).add({{"x", 1}}, x1);
+}
+
+class MockNearestNeighborIndex : public NearestNeighborIndex {
+private:
+ using Entry = std::pair<uint32_t, DoubleVector>;
+ using EntryVector = std::vector<Entry>;
+
+ const DocVectorAccess& _vectors;
+ EntryVector _adds;
+ EntryVector _removes;
+
+public:
+ MockNearestNeighborIndex(const DocVectorAccess& vectors)
+ : _vectors(vectors),
+ _adds(),
+ _removes()
+ {
+ }
+ void clear() {
+ _adds.clear();
+ _removes.clear();
+ }
+ void expect_empty_add() const {
+ EXPECT_TRUE(_adds.empty());
+ }
+ void expect_add(uint32_t exp_docid, const DoubleVector& exp_vector) const {
+ EXPECT_EQUAL(1u, _adds.size());
+ EXPECT_EQUAL(exp_docid, _adds.back().first);
+ EXPECT_EQUAL(exp_vector, _adds.back().second);
+ }
+ void expect_adds(const EntryVector &exp_adds) const {
+ EXPECT_EQUAL(exp_adds, _adds);
+ }
+ void expect_empty_remove() const {
+ EXPECT_TRUE(_removes.empty());
+ }
+ void expect_remove(uint32_t exp_docid, const DoubleVector& exp_vector) const {
+ EXPECT_EQUAL(1u, _removes.size());
+ EXPECT_EQUAL(exp_docid, _removes.back().first);
+ EXPECT_EQUAL(exp_vector, _removes.back().second);
+ }
+ void add_document(uint32_t docid) override {
+ auto vector = _vectors.get_vector(docid).typify<double>();
+ _adds.emplace_back(docid, DoubleVector(vector.begin(), vector.end()));
+ }
+ void remove_document(uint32_t docid) override {
+ auto vector = _vectors.get_vector(docid).typify<double>();
+ _removes.emplace_back(docid, DoubleVector(vector.begin(), vector.end()));
+ }
+ std::vector<uint32_t> find_top_k(uint32_t k, vespalib::tensor::TypedCells vector, uint32_t explore_k) override {
+ (void) k;
+ (void) vector;
+ (void) explore_k;
+ return std::vector<uint32_t>();
+ }
+};
+
+class MockNearestNeighborIndexFactory : public NearestNeighborIndexFactory {
+
+ std::unique_ptr<NearestNeighborIndex> make(const DocVectorAccess& vectors,
+ ValueType::CellType cell_type,
+ const search::attribute::HnswIndexParams& params) const override {
+ (void) params;
+ assert(cell_type == ValueType::CellType::DOUBLE);
+ return std::make_unique<MockNearestNeighborIndex>(vectors);
+ }
+};
+
struct Fixture
{
using BasicType = search::attribute::BasicType;
@@ -61,16 +147,20 @@ struct Fixture
Config _cfg;
vespalib::string _name;
vespalib::string _typeSpec;
+ std::unique_ptr<NearestNeighborIndexFactory> _index_factory;
std::shared_ptr<TensorAttribute> _tensorAttr;
std::shared_ptr<AttributeVector> _attr;
bool _denseTensors;
bool _useDenseTensorAttribute;
Fixture(const vespalib::string &typeSpec,
- bool useDenseTensorAttribute = false)
+ bool useDenseTensorAttribute = false,
+ bool enable_hnsw_index = false,
+ bool use_mock_index = false)
: _cfg(BasicType::TENSOR, CollectionType::SINGLE),
_name("test"),
_typeSpec(typeSpec),
+ _index_factory(std::make_unique<DefaultNearestNeighborIndexFactory>()),
_tensorAttr(),
_attr(),
_denseTensors(false),
@@ -80,20 +170,40 @@ struct Fixture
if (_cfg.tensorType().is_dense()) {
_denseTensors = true;
}
+ if (enable_hnsw_index) {
+ _cfg.set_hnsw_index_params(HnswIndexParams(4, 20));
+ if (use_mock_index) {
+ _index_factory = std::make_unique<MockNearestNeighborIndexFactory>();
+ }
+ }
_tensorAttr = makeAttr();
_attr = _tensorAttr;
_attr->addReservedDoc();
}
+ ~Fixture() {}
std::shared_ptr<TensorAttribute> makeAttr() {
if (_useDenseTensorAttribute) {
assert(_denseTensors);
- return std::make_shared<DenseTensorAttribute>(_name, _cfg);
+ return std::make_shared<DenseTensorAttribute>(_name, _cfg, *_index_factory);
} else {
return std::make_shared<GenericTensorAttribute>(_name, _cfg);
}
}
+ const DenseTensorAttribute& as_dense_tensor() const {
+ auto result = dynamic_cast<const DenseTensorAttribute*>(_tensorAttr.get());
+ assert(result != nullptr);
+ return *result;
+ }
+
+ MockNearestNeighborIndex& mock_index() {
+ assert(as_dense_tensor().nearest_neighbor_index() != nullptr);
+ auto mock_index = dynamic_cast<const MockNearestNeighborIndex*>(as_dense_tensor().nearest_neighbor_index());
+ assert(mock_index != nullptr);
+ return *const_cast<MockNearestNeighborIndex*>(mock_index);
+ }
+
void ensureSpace(uint32_t docId) {
while (_attr->getNumDocs() <= docId) {
uint32_t newDocId = 0u;
@@ -108,7 +218,15 @@ struct Fixture
_attr->commit();
}
- void setTensor(uint32_t docId, const Tensor &tensor) {
+ void set_tensor(uint32_t docid, const TensorSpec &spec) {
+ set_tensor_internal(docid, *createTensor(spec));
+ }
+
+ void set_empty_tensor(uint32_t docid) {
+ set_tensor_internal(docid, *_tensorAttr->getEmptyTensor());
+ }
+
+ void set_tensor_internal(uint32_t docId, const Tensor &tensor) {
ensureSpace(docId);
_tensorAttr->setTensor(docId, tensor);
_attr->commit();
@@ -119,27 +237,18 @@ struct Fixture
return _attr->getStatus();
}
- void
- assertGetNoTensor(uint32_t docId) {
+ void assertGetNoTensor(uint32_t docId) {
AttributeGuard guard(_attr);
Tensor::UP actTensor = _tensorAttr->getTensor(docId);
EXPECT_FALSE(actTensor);
}
- void
- assertGetTensor(const Tensor &expTensor, uint32_t docId)
- {
+ void assertGetTensor(const TensorSpec &expSpec, uint32_t docId) {
+ Tensor::UP expTensor = createTensor(expSpec);
AttributeGuard guard(_attr);
Tensor::UP actTensor = _tensorAttr->getTensor(docId);
EXPECT_TRUE(static_cast<bool>(actTensor));
- EXPECT_EQUAL(expTensor, *actTensor);
- }
-
- void
- assertGetTensor(const TensorSpec &expSpec, uint32_t docId)
- {
- Tensor::UP expTensor = createTensor(expSpec);
- assertGetTensor(*expTensor, docId);
+ EXPECT_EQUAL(*expTensor, *actTensor);
}
void save() {
@@ -154,23 +263,20 @@ struct Fixture
EXPECT_TRUE(loadok);
}
- Tensor::UP expDenseTensor3() const
- {
- return createTensor(TensorSpec(denseSpec)
- .add({{"x", 0}, {"y", 1}}, 11)
- .add({{"x", 1}, {"y", 2}}, 0));
+ TensorSpec expDenseTensor3() const {
+ return TensorSpec(denseSpec)
+ .add({{"x", 0}, {"y", 1}}, 11)
+ .add({{"x", 1}, {"y", 2}}, 0);
}
- Tensor::UP expDenseFillTensor() const
- {
- return createTensor(TensorSpec(denseSpec)
- .add({{"x", 0}, {"y", 0}}, 5)
- .add({{"x", 1}, {"y", 2}}, 0));
+ TensorSpec expDenseFillTensor() const {
+ return TensorSpec(denseSpec)
+ .add({{"x", 0}, {"y", 0}}, 5)
+ .add({{"x", 1}, {"y", 2}}, 0);
}
- Tensor::UP expEmptyDenseTensor() const
- {
- return createTensor(TensorSpec(denseSpec));
+ TensorSpec expEmptyDenseTensor() const {
+ return TensorSpec(denseSpec);
}
vespalib::string expEmptyDenseTensorSpec() const {
@@ -200,21 +306,21 @@ Fixture::testSetTensorValue()
EXPECT_EQUAL(5u, _attr->getNumDocs());
EXPECT_EQUAL(5u, _attr->getCommittedDocIdLimit());
TEST_DO(assertGetNoTensor(4));
- EXPECT_EXCEPTION(setTensor(4, *createTensor(TensorSpec("double"))),
+ EXPECT_EXCEPTION(set_tensor(4, TensorSpec("double")),
WrongTensorTypeException,
"but other tensor type is 'double'");
TEST_DO(assertGetNoTensor(4));
- setTensor(4, *_tensorAttr->getEmptyTensor());
+ set_empty_tensor(4);
if (_denseTensors) {
- TEST_DO(assertGetTensor(*expEmptyDenseTensor(), 4));
- setTensor(3, *expDenseTensor3());
- TEST_DO(assertGetTensor(*expDenseTensor3(), 3));
+ TEST_DO(assertGetTensor(expEmptyDenseTensor(), 4));
+ set_tensor(3, expDenseTensor3());
+ TEST_DO(assertGetTensor(expDenseTensor3(), 3));
} else {
TEST_DO(assertGetTensor(TensorSpec(sparseSpec), 4));
- setTensor(3, *createTensor(TensorSpec(sparseSpec)
- .add({{"x", ""}, {"y", ""}}, 11)));
+ set_tensor(3, TensorSpec(sparseSpec)
+ .add({{"x", ""}, {"y", ""}}, 11));
TEST_DO(assertGetTensor(TensorSpec(sparseSpec)
- .add({{"x", ""}, {"y", ""}}, 11), 3));
+ .add({{"x", ""}, {"y", ""}}, 11), 3));
}
TEST_DO(assertGetNoTensor(2));
TEST_DO(clearTensor(3));
@@ -225,23 +331,23 @@ void
Fixture::testSaveLoad()
{
ensureSpace(4);
- setTensor(4, *_tensorAttr->getEmptyTensor());
+ set_empty_tensor(4);
if (_denseTensors) {
- setTensor(3, *expDenseTensor3());
+ set_tensor(3, expDenseTensor3());
} else {
- setTensor(3, *createTensor(TensorSpec(sparseSpec)
- .add({{"x", ""}, {"y", "1"}}, 11)));
+ set_tensor(3, TensorSpec(sparseSpec)
+ .add({{"x", ""}, {"y", "1"}}, 11));
}
TEST_DO(save());
TEST_DO(load());
EXPECT_EQUAL(5u, _attr->getNumDocs());
EXPECT_EQUAL(5u, _attr->getCommittedDocIdLimit());
if (_denseTensors) {
- TEST_DO(assertGetTensor(*expDenseTensor3(), 3));
- TEST_DO(assertGetTensor(*expEmptyDenseTensor(), 4));
+ TEST_DO(assertGetTensor(expDenseTensor3(), 3));
+ TEST_DO(assertGetTensor(expEmptyDenseTensor(), 4));
} else {
TEST_DO(assertGetTensor(TensorSpec(sparseSpec)
- .add({{"x", ""}, {"y", "1"}}, 11), 3));
+ .add({{"x", ""}, {"y", "1"}}, 11), 3));
TEST_DO(assertGetTensor(TensorSpec(sparseSpec), 4));
}
TEST_DO(assertGetNoTensor(2));
@@ -256,29 +362,28 @@ Fixture::testCompaction()
return;
}
ensureSpace(4);
- Tensor::UP emptytensor = _tensorAttr->getEmptyTensor();
- Tensor::UP emptyxytensor = createTensor(TensorSpec(sparseSpec));
- Tensor::UP simpletensor = createTensor(TensorSpec(sparseSpec)
- .add({{"x", ""}, {"y", "1"}}, 11));
- Tensor::UP filltensor = createTensor(TensorSpec(sparseSpec)
- .add({{"x", ""}, {"y", ""}}, 5));
+ TensorSpec empty_xy_tensor(sparseSpec);
+ TensorSpec simple_tensor = TensorSpec(sparseSpec)
+ .add({{"x", ""}, {"y", "1"}}, 11);
+ TensorSpec fill_tensor = TensorSpec(sparseSpec)
+ .add({{"x", ""}, {"y", ""}}, 5);
if (_denseTensors) {
- emptyxytensor = expEmptyDenseTensor();
- simpletensor = expDenseTensor3();
- filltensor = expDenseFillTensor();
+ empty_xy_tensor = expEmptyDenseTensor();
+ simple_tensor = expDenseTensor3();
+ fill_tensor = expDenseFillTensor();
}
- setTensor(4, *emptytensor);
- setTensor(3, *simpletensor);
- setTensor(2, *filltensor);
+ set_empty_tensor(4);
+ set_tensor(3, simple_tensor);
+ set_tensor(2, fill_tensor);
clearTensor(2);
- setTensor(2, *filltensor);
+ set_tensor(2, fill_tensor);
search::attribute::Status oldStatus = getStatus();
search::attribute::Status newStatus = oldStatus;
uint64_t iter = 0;
uint64_t iterLimit = 100000;
for (; iter < iterLimit; ++iter) {
clearTensor(2);
- setTensor(2, *filltensor);
+ set_tensor(2, fill_tensor);
newStatus = getStatus();
if (newStatus.getUsed() < oldStatus.getUsed()) {
break;
@@ -290,9 +395,9 @@ Fixture::testCompaction()
"iter = %" PRIu64 ", memory usage %" PRIu64 ", -> %" PRIu64,
iter, oldStatus.getUsed(), newStatus.getUsed());
TEST_DO(assertGetNoTensor(1));
- TEST_DO(assertGetTensor(*filltensor, 2));
- TEST_DO(assertGetTensor(*simpletensor, 3));
- TEST_DO(assertGetTensor(*emptyxytensor, 4));
+ TEST_DO(assertGetTensor(fill_tensor, 2));
+ TEST_DO(assertGetTensor(simple_tensor, 3));
+ TEST_DO(assertGetTensor(empty_xy_tensor, 4));
}
void
@@ -357,4 +462,73 @@ TEST("Test dense tensors with dense tensor attribute")
testAll([]() { return std::make_shared<Fixture>(denseSpec, true); });
}
+TEST_F("Hnsw index is NOT instantiated in dense tensor attribute by default",
+ Fixture(vec_2d_spec, true, false))
+{
+ const auto& tensor = f.as_dense_tensor();
+ EXPECT_TRUE(tensor.nearest_neighbor_index() == nullptr);
+}
+
+TEST_F("Hnsw index is instantiated in dense tensor attribute when specified in config",
+ Fixture(vec_2d_spec, true, true))
+{
+ const auto& tensor = f.as_dense_tensor();
+ ASSERT_TRUE(tensor.nearest_neighbor_index() != nullptr);
+ auto hnsw_index = dynamic_cast<const HnswIndex*>(tensor.nearest_neighbor_index());
+ ASSERT_TRUE(hnsw_index != nullptr);
+
+ const auto& cfg = hnsw_index->config();
+ EXPECT_EQUAL(8u, cfg.max_links_at_level_0());
+ EXPECT_EQUAL(4u, cfg.max_links_at_hierarchic_levels());
+ EXPECT_EQUAL(20u, cfg.neighbors_to_explore_at_construction());
+ EXPECT_TRUE(cfg.heuristic_select_neighbors());
+}
+
+class DenseTensorAttributeMockIndex : public Fixture {
+public:
+ DenseTensorAttributeMockIndex() : Fixture(vec_2d_spec, true, true, true) {}
+};
+
+TEST_F("setTensor() updates nearest neighbor index", DenseTensorAttributeMockIndex)
+{
+ auto& index = f.mock_index();
+
+ f.set_tensor(1, vec_2d(3, 5));
+ index.expect_add(1, {3, 5});
+ index.expect_empty_remove();
+ index.clear();
+
+ // Replaces previous value.
+ f.set_tensor(1, vec_2d(7, 9));
+ index.expect_remove(1, {3, 5});
+ index.expect_add(1, {7, 9});
+}
+
+TEST_F("clearDoc() updates nearest neighbor index", DenseTensorAttributeMockIndex)
+{
+ auto& index = f.mock_index();
+
+ // Nothing to clear.
+ f.clearTensor(1);
+ index.expect_empty_remove();
+ index.expect_empty_add();
+
+ // Clears previous value.
+ f.set_tensor(1, vec_2d(3, 5));
+ index.clear();
+ f.clearTensor(1);
+ index.expect_remove(1, {3, 5});
+ index.expect_empty_add();
+}
+
+TEST_F("onLoad() updates nearest neighbor index", DenseTensorAttributeMockIndex)
+{
+ f.set_tensor(1, vec_2d(3, 5));
+ f.set_tensor(2, vec_2d(7, 9));
+ f.save();
+ f.load();
+ auto& index = f.mock_index();
+ index.expect_adds({{1, {3, 5}}, {2, {7, 9}}});
+}
+
TEST_MAIN() { TEST_RUN_ALL(); vespalib::unlink("test.dat"); }
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index c6246bb8434..cd0d4bcaad0 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -48,8 +48,7 @@ using HnswIndexUP = std::unique_ptr<HnswIndex>;
class HnswIndexTest : public ::testing::Test {
public:
FloatVectors vectors;
- FloatSqEuclideanDistance distance_func;
- LevelGenerator level_generator;
+ LevelGenerator* level_generator;
HnswIndexUP index;
HnswIndexTest()
@@ -62,11 +61,14 @@ public:
.set(7, {3, 5}).set(8, {0, 3}).set(9, {4, 5});
}
void init(bool heuristic_select_neighbors) {
- index = std::make_unique<HnswIndex>(vectors, distance_func, level_generator,
+ auto generator = std::make_unique<LevelGenerator>();
+ level_generator = generator.get();
+ index = std::make_unique<HnswIndex>(vectors, std::make_unique<FloatSqEuclideanDistance>(),
+ std::move(generator),
HnswIndex::Config(2, 1, 10, heuristic_select_neighbors));
}
void add_document(uint32_t docid, uint32_t max_level = 0) {
- level_generator.level = max_level;
+ level_generator->level = max_level;
index->add_document(docid);
}
void remove_document(uint32_t docid) {
diff --git a/searchlib/src/vespa/searchlib/attribute/configconverter.cpp b/searchlib/src/vespa/searchlib/attribute/configconverter.cpp
index 535e81fc032..10e1a1edb52 100644
--- a/searchlib/src/vespa/searchlib/attribute/configconverter.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/configconverter.cpp
@@ -73,6 +73,10 @@ ConfigConverter::convert(const AttributesConfig::Attribute & cfg)
predicateParams.setBounds(cfg.lowerbound, cfg.upperbound);
predicateParams.setDensePostingListThreshold(cfg.densepostinglistthreshold);
retval.setPredicateParams(predicateParams);
+ if (cfg.index.hnsw.enabled) {
+ retval.set_hnsw_index_params(HnswIndexParams(cfg.index.hnsw.maxlinkspernode,
+ cfg.index.hnsw.neighborstoexploreatinsert));
+ }
if (retval.basicType().type() == BasicType::Type::TENSOR) {
if (!cfg.tensortype.empty()) {
retval.setTensorType(ValueType::from_spec(cfg.tensortype));
diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
index 9175168248c..09069861ab4 100644
--- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
@@ -1,6 +1,7 @@
# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
vespa_add_library(searchlib_tensor OBJECT
SOURCES
+ default_nearest_neighbor_index_factory.cpp
dense_tensor_attribute.cpp
dense_tensor_attribute_saver.cpp
dense_tensor_store.cpp
diff --git a/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp
new file mode 100644
index 00000000000..68efe6417c0
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp
@@ -0,0 +1,51 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "default_nearest_neighbor_index_factory.h"
+#include "distance_functions.h"
+#include "hnsw_index.h"
+#include "random_level_generator.h"
+#include <vespa/searchcommon/attribute/config.h>
+
+namespace search::tensor {
+
+using vespalib::eval::ValueType;
+
+namespace {
+
+class LevelZeroGenerator : public RandomLevelGenerator {
+ uint32_t max_level() override { return 0; }
+};
+
+DistanceFunction::UP
+make_distance_function(ValueType::CellType cell_type)
+{
+ if (cell_type == ValueType::CellType::FLOAT) {
+ return std::make_unique<SquaredEuclideanDistance<float>>();
+ } else {
+ return std::make_unique<SquaredEuclideanDistance<double>>();
+ }
+}
+
+RandomLevelGenerator::UP
+make_random_level_generator()
+{
+ // TODO: Make generator that results in hierarchical graph.
+ return std::make_unique<LevelZeroGenerator>();
+}
+
+}
+
+std::unique_ptr<NearestNeighborIndex>
+DefaultNearestNeighborIndexFactory::make(const DocVectorAccess& vectors,
+ vespalib::eval::ValueType::CellType cell_type,
+ const search::attribute::HnswIndexParams& params) const
+{
+ HnswIndex::Config cfg(params.max_links_per_node() * 2,
+ params.max_links_per_node(),
+ params.neighbors_to_explore_at_insert(),
+ true);
+ return std::make_unique<HnswIndex>(vectors, make_distance_function(cell_type), make_random_level_generator(), cfg);
+}
+
+}
+
diff --git a/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.h b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.h
new file mode 100644
index 00000000000..ea784efdb51
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.h
@@ -0,0 +1,19 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "nearest_neighbor_index_factory.h"
+
+namespace search::tensor {
+
+/**
+ * Factory that instantiates the production hnsw index.
+ */
+class DefaultNearestNeighborIndexFactory : public NearestNeighborIndexFactory {
+public:
+ std::unique_ptr<NearestNeighborIndex> make(const DocVectorAccess& vectors,
+ vespalib::eval::ValueType::CellType cell_type,
+ const search::attribute::HnswIndexParams& params) const override;
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
index a2b9f136ed9..171340e07f1 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
@@ -2,6 +2,7 @@
#include "dense_tensor_attribute.h"
#include "dense_tensor_attribute_saver.h"
+#include "nearest_neighbor_index.h"
#include "tensor_attribute.hpp"
#include <vespa/eval/tensor/tensor.h>
#include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h>
@@ -55,11 +56,23 @@ TensorReader::is_present() {
}
-DenseTensorAttribute::DenseTensorAttribute(vespalib::stringref baseFileName,
- const Config &cfg)
+void
+DenseTensorAttribute::consider_remove_from_index(DocId docid)
+{
+ if (_index && _refVector[docid].valid()) {
+ _index->remove_document(docid);
+ }
+}
+
+DenseTensorAttribute::DenseTensorAttribute(vespalib::stringref baseFileName, const Config& cfg,
+ const NearestNeighborIndexFactory& index_factory)
: TensorAttribute(baseFileName, cfg, _denseTensorStore),
- _denseTensorStore(cfg.tensorType())
+ _denseTensorStore(cfg.tensorType()),
+ _index()
{
+ if (cfg.hnsw_index_params().has_value()) {
+ _index = index_factory.make(*this, cfg.tensorType().cell_type(), cfg.hnsw_index_params().value());
+ }
}
@@ -69,12 +82,23 @@ DenseTensorAttribute::~DenseTensorAttribute()
_tensorStore.clearHoldLists();
}
+uint32_t
+DenseTensorAttribute::clearDoc(DocId docId)
+{
+ consider_remove_from_index(docId);
+ return TensorAttribute::clearDoc(docId);
+}
+
void
DenseTensorAttribute::setTensor(DocId docId, const Tensor &tensor)
{
checkTensorType(tensor);
+ consider_remove_from_index(docId);
EntryRef ref = _denseTensorStore.setTensor(tensor);
setTensorRef(docId, ref);
+ if (_index) {
+ _index->add_document(docId);
+ }
}
@@ -120,6 +144,11 @@ DenseTensorAttribute::onLoad()
auto raw = _denseTensorStore.allocRawBuffer();
tensorReader.readTensor(raw.data, _denseTensorStore.getBufSize());
_refVector.push_back(raw.ref);
+ if (_index) {
+ // This ensures that get_vector() (via getTensor()) is able to find the newly added tensor.
+ setCommittedDocIdLimit(lid + 1);
+ _index->add_document(lid);
+ }
} else {
_refVector.push_back(EntryRef());
}
@@ -154,4 +183,12 @@ DenseTensorAttribute::getVersion() const
return DENSE_TENSOR_ATTRIBUTE_VERSION;
}
+vespalib::tensor::TypedCells
+DenseTensorAttribute::get_vector(uint32_t docid) const
+{
+ MutableDenseTensorView tensor_view(_denseTensorStore.type());
+ getTensor(docid, tensor_view);
+ return tensor_view.cellsRef();
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h
index 593741cef39..f9a8a81b56b 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h
@@ -2,35 +2,47 @@
#pragma once
-#include "tensor_attribute.h"
+#include "default_nearest_neighbor_index_factory.h"
#include "dense_tensor_store.h"
+#include "doc_vector_access.h"
+#include "tensor_attribute.h"
+#include <memory>
-namespace vespalib { namespace tensor { class MutableDenseTensorView; }}
+namespace vespalib::tensor { class MutableDenseTensorView; }
-namespace search {
+namespace search::tensor {
-namespace tensor {
+class NearestNeighborIndex;
/**
* Attribute vector class used to store dense tensors for all
* documents in memory.
*/
-class DenseTensorAttribute : public TensorAttribute
-{
+class DenseTensorAttribute : public TensorAttribute, public DocVectorAccess {
+private:
DenseTensorStore _denseTensorStore;
+ std::unique_ptr<NearestNeighborIndex> _index;
+
+ void consider_remove_from_index(DocId docid);
+
public:
- DenseTensorAttribute(vespalib::stringref baseFileName, const Config &cfg);
+ DenseTensorAttribute(vespalib::stringref baseFileName, const Config& cfg,
+ const NearestNeighborIndexFactory& index_factory = DefaultNearestNeighborIndexFactory());
virtual ~DenseTensorAttribute();
- virtual void setTensor(DocId docId, const Tensor &tensor) override;
- virtual std::unique_ptr<Tensor> getTensor(DocId docId) const override;
- virtual void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override;
- virtual bool onLoad() override;
- virtual std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override;
- virtual void compactWorst() override;
- virtual uint32_t getVersion() const override;
+ // Implements TensorAttribute
+ uint32_t clearDoc(DocId docId) override;
+ void setTensor(DocId docId, const Tensor &tensor) override;
+ std::unique_ptr<Tensor> getTensor(DocId docId) const override;
+ void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override;
+ bool onLoad() override;
+ std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override;
+ void compactWorst() override;
+ uint32_t getVersion() const override;
+
+ // Implements DocVectorAccess
+ vespalib::tensor::TypedCells get_vector(uint32_t docid) const override;
+
+ const NearestNeighborIndex* nearest_neighbor_index() const { return _index.get(); }
};
-
-} // namespace search::tensor
-
-} // namespace search
+}
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function.h b/searchlib/src/vespa/searchlib/tensor/distance_function.h
index 8dfb77ddccb..b682824c805 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function.h
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function.h
@@ -2,6 +2,8 @@
#pragma once
+#include <memory>
+
namespace vespalib::tensor { struct TypedCells; }
namespace search::tensor {
@@ -14,6 +16,7 @@ namespace search::tensor {
*/
class DistanceFunction {
public:
+ using UP = std::unique_ptr<DistanceFunction>;
virtual ~DistanceFunction() {}
virtual double calc(const vespalib::tensor::TypedCells& lhs, const vespalib::tensor::TypedCells& rhs) const = 0;
};
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_functions.h b/searchlib/src/vespa/searchlib/tensor/distance_functions.h
index 1e8727e92aa..494d1a859b6 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_functions.h
+++ b/searchlib/src/vespa/searchlib/tensor/distance_functions.h
@@ -3,6 +3,7 @@
#pragma once
#include "distance_function.h"
+#include <vespa/eval/tensor/dense/typed_cells.h>
namespace search::tensor {
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index be53b758841..860686f3c6a 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -44,7 +44,7 @@ HnswIndex::max_links_for_level(uint32_t level) const
uint32_t
HnswIndex::make_node_for_document(uint32_t docid)
{
- uint32_t max_level = _level_generator.max_level();
+ uint32_t max_level = _level_generator->max_level();
// TODO: Add capping on num_levels
uint32_t num_levels = max_level + 1;
// Note: The level array instance lives as long as the document is present in the index.
@@ -170,7 +170,7 @@ double
HnswIndex::calc_distance(const TypedCells& lhs, uint32_t rhs_docid) const
{
auto rhs = get_vector(rhs_docid);
- return _distance_func.calc(lhs, rhs);
+ return _distance_func->calc(lhs, rhs);
}
HnswCandidate
@@ -227,11 +227,11 @@ HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find, Fur
}
}
-HnswIndex::HnswIndex(const DocVectorAccess& vectors, const DistanceFunction& distance_func,
- RandomLevelGenerator& level_generator, const Config& cfg)
+HnswIndex::HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func,
+ RandomLevelGenerator::UP level_generator, const Config& cfg)
: _vectors(vectors),
- _distance_func(distance_func),
- _level_generator(level_generator),
+ _distance_func(std::move(distance_func)),
+ _level_generator(std::move(level_generator)),
_cfg(cfg),
_node_refs(),
_nodes(make_default_node_store_config()),
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
index 814148072ca..66d6a6d25c2 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
@@ -2,10 +2,12 @@
#pragma once
+#include "distance_function.h"
#include "doc_vector_access.h"
#include "hnsw_index_utils.h"
#include "hnsw_node.h"
#include "nearest_neighbor_index.h"
+#include "random_level_generator.h"
#include <vespa/eval/tensor/dense/typed_cells.h>
#include <vespa/searchlib/common/bitvector.h>
#include <vespa/vespalib/datastore/array_store.h>
@@ -15,9 +17,6 @@
namespace search::tensor {
-class DistanceFunction;
-class RandomLevelGenerator;
-
/**
* Implementation of a hierarchical navigable small world graph (HNSW)
* that is used for approximate K-nearest neighbor search.
@@ -82,8 +81,8 @@ protected:
using TypedCells = vespalib::tensor::TypedCells;
const DocVectorAccess& _vectors;
- const DistanceFunction& _distance_func;
- RandomLevelGenerator& _level_generator;
+ DistanceFunction::UP _distance_func;
+ RandomLevelGenerator::UP _level_generator;
Config _cfg;
NodeRefVector _node_refs;
NodeStore _nodes;
@@ -128,10 +127,12 @@ protected:
void search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors, uint32_t level);
public:
- HnswIndex(const DocVectorAccess& vectors, const DistanceFunction& distance_func,
- RandomLevelGenerator& level_generator, const Config& cfg);
+ HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func,
+ RandomLevelGenerator::UP level_generator, const Config& cfg);
~HnswIndex() override;
+ const Config& config() const { return _cfg; }
+
void add_document(uint32_t docid) override;
void remove_document(uint32_t docid) override;
std::vector<uint32_t> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) override;
diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index_factory.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index_factory.h
new file mode 100644
index 00000000000..c09403df5e0
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index_factory.h
@@ -0,0 +1,26 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/eval/eval/value_type.h>
+#include <memory>
+
+namespace search::attribute { class HnswIndexParams; }
+
+namespace search::tensor {
+
+class DocVectorAccess;
+class NearestNeighborIndex;
+
+/**
+ * Factory interface used to instantiate an index used for (approximate) nearest neighbor search.
+ */
+class NearestNeighborIndexFactory {
+public:
+ virtual ~NearestNeighborIndexFactory() {}
+ virtual std::unique_ptr<NearestNeighborIndex> make(const DocVectorAccess& vectors,
+ vespalib::eval::ValueType::CellType cell_type,
+ const search::attribute::HnswIndexParams& params) const = 0;
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/tensor/random_level_generator.h b/searchlib/src/vespa/searchlib/tensor/random_level_generator.h
index 0fcac977d9d..0f4c7c34445 100644
--- a/searchlib/src/vespa/searchlib/tensor/random_level_generator.h
+++ b/searchlib/src/vespa/searchlib/tensor/random_level_generator.h
@@ -2,6 +2,8 @@
#pragma once
+#include <memory>
+
namespace search::tensor {
/**
@@ -9,6 +11,7 @@ namespace search::tensor {
*/
class RandomLevelGenerator {
public:
+ using UP = std::unique_ptr<RandomLevelGenerator>;
virtual ~RandomLevelGenerator() {}
virtual uint32_t max_level() = 0;
};