summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2023-04-20 09:09:33 +0200
committerGitHub <noreply@github.com>2023-04-20 09:09:33 +0200
commit3cde18008cff2d1c812ec86141f538b56d5248ab (patch)
treeeaca23b97d05abf775dfaf68b495cae70107c450 /searchlib
parent4bf83d5e87a8896ce3b6a14fb0889a2891053bf1 (diff)
parent732e4c4be8bbc5a43e3adae5db222301e630bd8c (diff)
Merge pull request #26783 from vespa-engine/arnej/refactor-with-bound-distance
add mimimal version of BoundDistanceFunction
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp17
-rw-r--r--searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp12
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp25
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp8
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h2
-rw-r--r--searchlib/src/vespa/searchlib/tensor/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp3
-rw-r--r--searchlib/src/vespa/searchlib/tensor/bound_distance_function.h44
-rw-r--r--searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp4
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp29
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_calculator.h21
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_function.h19
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp56
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_function_factory.h19
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp90
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.h40
-rw-r--r--searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h2
-rw-r--r--searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h3
-rw-r--r--searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h7
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp8
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_attribute.h2
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h7
25 files changed, 311 insertions, 121 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index 28c50891225..e3c9e05073e 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -294,30 +294,33 @@ public:
std::unique_ptr<NearestNeighborIndexLoader> make_loader(FastOS_FileInterface& file) override {
return std::make_unique<MockIndexLoader>(_index_value, file);
}
- std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, uint32_t explore_k,
+ std::vector<Neighbor> find_top_k(uint32_t k,
+ const search::tensor::BoundDistanceFunction &df,
+ uint32_t explore_k,
double distance_threshold) const override
{
(void) k;
- (void) vector;
+ (void) df;
(void) explore_k;
(void) distance_threshold;
return std::vector<Neighbor>();
}
- std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector,
+ std::vector<Neighbor> find_top_k_with_filter(uint32_t k,
+ const search::tensor::BoundDistanceFunction &df,
const GlobalFilter& filter, uint32_t explore_k,
double distance_threshold) const override
{
(void) k;
- (void) vector;
+ (void) df;
(void) explore_k;
(void) filter;
(void) distance_threshold;
return std::vector<Neighbor>();
}
- const search::tensor::DistanceFunction *distance_function() const override {
- static search::tensor::SquaredEuclideanDistance my_dist_fun(vespalib::eval::CellType::DOUBLE);
- return &my_dist_fun;
+ search::tensor::DistanceFunctionFactory &distance_function_factory() const override {
+ static search::tensor::DistanceFunctionFactory::UP my_dist_fun = search::tensor::make_distance_function_factory(search::attribute::DistanceMetric::Euclidean, vespalib::eval::CellType::DOUBLE);
+ return *my_dist_fun;
}
};
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 2801bf90080..fd07529795a 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -134,7 +134,9 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._attr);
- DistanceCalculator dist_calc(attr, qtv, env.dist_fun());
+
+ auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, qtv.cells().type);
+ DistanceCalculator dist_calc(attr, dff->for_query_vector(qtv.cells()));
NearestNeighborDistanceHeap dh(2);
dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold));
const GlobalFilter &filter = *env._global_filter;
@@ -260,7 +262,8 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) {
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._attr);
- DistanceCalculator dist_calc(attr, qtv, env.dist_fun());
+ auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, qtv.cells().type);
+ DistanceCalculator dist_calc(attr, dff->for_query_vector(qtv.cells()));
NearestNeighborDistanceHeap dh(2);
auto dummy_filter = GlobalFilter::create();
auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, *dummy_filter);
@@ -333,7 +336,10 @@ TEST(NnsIndexIteratorTest, require_that_iterator_works_as_expected) {
std::vector<NnsIndexIterator::Hit> hits{{2,4.0}, {3,9.0}, {5,1.0}, {8,16.0}, {9,36.0}};
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
- auto search = NnsIndexIterator::create(tfmd, hits, *euclid_d);
+ auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, CellType::DOUBLE);
+ vespalib::eval::TypedCells dummy;
+ auto df = dff->for_query_vector(dummy);
+ auto search = NnsIndexIterator::create(tfmd, hits, *df);
search->initFullRange();
expect_not_match(*search, 1, 2);
expect_match(*search, 2);
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index d9230849699..9f6216f5867 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -105,10 +105,16 @@ public:
~HnswIndexTest() {}
+ auto dff() {
+ return search::tensor::make_distance_function_factory(
+ search::attribute::DistanceMetric::Euclidean,
+ vespalib::eval::CellType::FLOAT);
+ }
+
void init(bool heuristic_select_neighbors) {
auto generator = std::make_unique<LevelGenerator>();
level_generator = generator.get();
- index = std::make_unique<IndexType>(vectors, std::make_unique<SquaredEuclideanDistance>(vespalib::eval::CellType::FLOAT),
+ index = std::make_unique<IndexType>(vectors, dff(),
std::move(generator),
HnswIndexConfig(5, 2, 10, 0, heuristic_select_neighbors));
}
@@ -165,9 +171,10 @@ public:
uint32_t explore_k = 100;
vespalib::ArrayRef qv_ref(qv);
vespalib::eval::TypedCells qv_cells(qv_ref);
+ auto df = index->distance_function_factory().for_query_vector(qv_cells);
auto got_by_docid = (global_filter->is_active()) ?
- index->find_top_k_with_filter(k, qv_cells, *global_filter, explore_k, 10000.0) :
- index->find_top_k(k, qv_cells, explore_k, 10000.0);
+ index->find_top_k_with_filter(k, *df, *global_filter, explore_k, 10000.0) :
+ index->find_top_k(k, *df, explore_k, 10000.0);
std::vector<uint32_t> act;
act.reserve(got_by_docid.size());
for (auto& hit : got_by_docid) {
@@ -178,7 +185,8 @@ public:
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
uint32_t k = 3;
auto qv = vectors.get_vector(docid, 0);
- auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek();
+ auto df = index->distance_function_factory().for_query_vector(qv);
+ auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
size_t idx = 0;
for (const auto & hit : rv) {
@@ -189,7 +197,7 @@ public:
if (exp_hits.size() == k) {
std::vector<uint32_t> expected_by_docid = exp_hits;
std::sort(expected_by_docid.begin(), expected_by_docid.end());
- auto got_by_docid = index->find_top_k(k, qv, k, 100100.25);
+ auto got_by_docid = index->find_top_k(k, *df, k, 100100.25);
for (idx = 0; idx < k; ++idx) {
EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid);
}
@@ -198,15 +206,16 @@ public:
}
void check_with_distance_threshold(uint32_t docid) {
auto qv = vectors.get_vector(docid, 0);
+ auto df = index->distance_function_factory().for_query_vector(qv);
uint32_t k = 3;
- auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek();
+ auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
EXPECT_EQ(rv.size(), 3);
EXPECT_LE(rv[0].distance, rv[1].distance);
double thr = (rv[0].distance + rv[1].distance) * 0.5;
auto got_by_docid = (global_filter->is_active())
- ? index->find_top_k_with_filter(k, qv, *global_filter, k, thr)
- : index->find_top_k(k, qv, k, thr);
+ ? index->find_top_k_with_filter(k, *df, *global_filter, k, thr)
+ : index->find_top_k(k, *df, k, thr);
EXPECT_EQ(got_by_docid.size(), 1);
EXPECT_EQ(got_by_docid[0].docid, index->get_docid(rv[0].nodeid));
for (const auto & hit : got_by_docid) {
diff --git a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
index ecf310798af..0dcd77ec392 100644
--- a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
@@ -261,9 +261,15 @@ public:
~Stressor() {}
+ auto dff() {
+ return search::tensor::make_distance_function_factory(
+ search::attribute::DistanceMetric::Euclidean,
+ vespalib::eval::CellType::FLOAT);
+ }
+
void init() {
uint32_t m = 16;
- index = std::make_unique<IndexType>(vectors, std::make_unique<FloatSqEuclideanDistance>(),
+ index = std::make_unique<IndexType>(vectors, dff(),
std::make_unique<InvLogLevelGenerator>(m),
HnswIndexConfig(2*m, m, 200, 10, true));
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index 7fdf5230325..7c307a1e35f 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -106,13 +106,13 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter, d
void
NearestNeighborBlueprint::perform_top_k(const search::tensor::NearestNeighborIndex* nns_index)
{
- auto lhs = _query_tensor.cells();
uint32_t k = _adjusted_target_hits;
+ const auto &df = _distance_calc->function();
if (_global_filter->is_active()) {
- _found_hits = nns_index->find_top_k_with_filter(k, lhs, *_global_filter, k + _explore_additional_hits, _distance_threshold);
+ _found_hits = nns_index->find_top_k_with_filter(k, df, *_global_filter, k + _explore_additional_hits, _distance_threshold);
_algorithm = Algorithm::INDEX_TOP_K_WITH_FILTER;
} else {
- _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold);
+ _found_hits = nns_index->find_top_k(k, df, k + _explore_additional_hits, _distance_threshold);
_algorithm = Algorithm::INDEX_TOP_K;
}
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp
index 95264a79431..5ec4357ca24 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp
@@ -18,13 +18,13 @@ class NeighborVectorIterator : public NnsIndexIterator
private:
fef::TermFieldMatchData &_tfmd;
const std::vector<Neighbor> &_hits;
- const search::tensor::DistanceFunction &_dist_fun;
+ const search::tensor::BoundDistanceFunction &_dist_fun;
uint32_t _idx;
double _last_abstract_dist;
public:
NeighborVectorIterator(fef::TermFieldMatchData &tfmd,
const std::vector<Neighbor> &hits,
- const search::tensor::DistanceFunction &dist_fun)
+ const search::tensor::BoundDistanceFunction &dist_fun)
: _tfmd(tfmd),
_hits(hits),
_dist_fun(dist_fun),
@@ -65,7 +65,7 @@ std::unique_ptr<NnsIndexIterator>
NnsIndexIterator::create(
fef::TermFieldMatchData &tfmd,
const std::vector<Neighbor> &hits,
- const search::tensor::DistanceFunction &dist_fun)
+ const search::tensor::BoundDistanceFunction &dist_fun)
{
return std::make_unique<NeighborVectorIterator>(tfmd, hits, dist_fun);
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h
index 031a603de49..84ff0f04813 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h
@@ -16,7 +16,7 @@ public:
static std::unique_ptr<NnsIndexIterator> create(
fef::TermFieldMatchData &tfmd,
const std::vector<Hit> &hits,
- const search::tensor::DistanceFunction &dist_fun);
+ const search::tensor::BoundDistanceFunction &dist_fun);
};
} // namespace
diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
index 313863d8dcb..090042e5b83 100644
--- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
@@ -3,6 +3,7 @@ vespa_add_library(searchlib_tensor OBJECT
SOURCES
angular_distance.cpp
bitvector_visited_tracker.cpp
+ bound_distance_function.cpp
default_nearest_neighbor_index_factory.cpp
dense_tensor_attribute.cpp
dense_tensor_store.cpp
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
new file mode 100644
index 00000000000..33b94e5218c
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
@@ -0,0 +1,3 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "bound_distance_function.h"
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
new file mode 100644
index 00000000000..17e9e49cada
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
@@ -0,0 +1,44 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <memory>
+#include <vespa/eval/eval/cell_type.h>
+#include <vespa/eval/eval/typed_cells.h>
+#include <vespa/vespalib/util/array.h>
+#include <vespa/vespalib/util/arrayref.h>
+#include "distance_function.h"
+
+namespace vespalib::eval { struct TypedCells; }
+
+namespace search::tensor {
+
+/**
+ * Interface used to calculate the distance from a prebound n-dimensional vector.
+ *
+ * The actual implementation must know which type the vectors are.
+ */
+class BoundDistanceFunction : public DistanceConverter {
+private:
+ vespalib::eval::CellType _expect_cell_type;
+public:
+ using UP = std::unique_ptr<BoundDistanceFunction>;
+
+ BoundDistanceFunction(vespalib::eval::CellType expected) : _expect_cell_type(expected) {}
+
+ virtual ~BoundDistanceFunction() = default;
+
+ // input vectors will be converted to this cell type:
+ vespalib::eval::CellType expected_cell_type() const {
+ return _expect_cell_type;
+ }
+
+ // calculate internal distance (comparable)
+ virtual double calc(const vespalib::eval::TypedCells& rhs) const = 0;
+
+ // calculate internal distance, early return allowed if > limit
+ virtual double calc_with_limit(const vespalib::eval::TypedCells& rhs,
+ double limit) const = 0;
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp
index 4f6f8ac5c87..77c912dc690 100644
--- a/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp
@@ -41,12 +41,12 @@ DefaultNearestNeighborIndexFactory::make(const DocVectorAccess& vectors,
true);
if (multi_vector_index) {
return std::make_unique<HnswIndex<HnswIndexType::MULTI>>(vectors,
- make_distance_function(params.distance_metric(), cell_type),
+ make_distance_function_factory(params.distance_metric(), cell_type),
make_random_level_generator(m),
cfg);
} else {
return std::make_unique<HnswIndex<HnswIndexType::SINGLE>>(vectors,
- make_distance_function(params.distance_metric(), cell_type),
+ make_distance_function_factory(params.distance_metric(), cell_type),
make_random_level_generator(m),
cfg);
}
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
index b669b5ffea6..8da777d97eb 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
@@ -47,8 +47,6 @@ struct ConvertCellsSelector
}
};
-
-
}
namespace search::tensor {
@@ -58,36 +56,27 @@ DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tens
: _attr_tensor(attr_tensor),
_query_tensor_uptr(),
_query_tensor(&query_tensor_in),
- _query_tensor_cells(),
- _dist_fun_uptr(make_distance_function(_attr_tensor.distance_metric(),
- _attr_tensor.getTensorType().cell_type())),
- _dist_fun(_dist_fun_uptr.get())
+ _dist_fun()
{
- assert(_dist_fun);
- auto nns_index = _attr_tensor.nearest_neighbor_index();
- if (nns_index) {
- _dist_fun = nns_index->distance_function();
- assert(_dist_fun);
- }
+ auto * nns_index = _attr_tensor.nearest_neighbor_index();
+ auto & dff = nns_index ? nns_index->distance_function_factory() : attr_tensor.distance_function_factory();
auto query_ct = _query_tensor->cells().type;
- CellType required_ct = _dist_fun->expected_cell_type();
+ CellType required_ct = dff.expected_cell_type;
if (query_ct != required_ct) {
ConvertCellsSelector converter;
_query_tensor_uptr = converter(query_ct, required_ct, *_query_tensor);
_query_tensor = _query_tensor_uptr.get();
}
- _query_tensor_cells = _query_tensor->cells();
+ _dist_fun = dff.for_query_vector(_query_tensor->cells());
+ assert(_dist_fun);
}
DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tensor,
- const vespalib::eval::Value& query_tensor_in,
- const DistanceFunction& function_in)
+ BoundDistanceFunction::UP function_in)
: _attr_tensor(attr_tensor),
_query_tensor_uptr(),
- _query_tensor(&query_tensor_in),
- _query_tensor_cells(_query_tensor->cells()),
- _dist_fun_uptr(),
- _dist_fun(&function_in)
+ _query_tensor(nullptr),
+ _dist_fun(std::move(function_in))
{
}
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
index 6b4cf142264..a3ca771e30c 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
+++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
@@ -2,6 +2,7 @@
#pragma once
#include "distance_function.h"
+#include "distance_function_factory.h"
#include "i_tensor_attribute.h"
#include "vector_bundle.h"
#include <optional>
@@ -23,9 +24,7 @@ private:
const tensor::ITensorAttribute& _attr_tensor;
std::unique_ptr<vespalib::eval::Value> _query_tensor_uptr;
const vespalib::eval::Value* _query_tensor;
- vespalib::eval::TypedCells _query_tensor_cells;
- std::unique_ptr<DistanceFunction> _dist_fun_uptr;
- const DistanceFunction* _dist_fun;
+ std::unique_ptr<BoundDistanceFunction> _dist_fun;
public:
DistanceCalculator(const tensor::ITensorAttribute& attr_tensor,
@@ -35,20 +34,22 @@ public:
* Only used by unit tests where ownership of query tensor and distance function is handled outside.
*/
DistanceCalculator(const tensor::ITensorAttribute& attr_tensor,
- const vespalib::eval::Value& query_tensor_in,
- const DistanceFunction& function_in);
+ BoundDistanceFunction::UP function_in);
~DistanceCalculator();
const tensor::ITensorAttribute& attribute_tensor() const { return _attr_tensor; }
- const vespalib::eval::Value& query_tensor() const { return *_query_tensor; }
- const DistanceFunction& function() const { return *_dist_fun; }
+ const vespalib::eval::Value& query_tensor() const {
+ assert(_query_tensor != nullptr);
+ return *_query_tensor;
+ }
+ const BoundDistanceFunction& function() const { return *_dist_fun; }
double calc_raw_score(uint32_t docid) const {
auto vectors = _attr_tensor.get_vectors(docid);
double result = 0.0;
for (uint32_t i = 0; i < vectors.subspaces(); ++i) {
- double distance = _dist_fun->calc(_query_tensor_cells, vectors.cells(i));
+ double distance = _dist_fun->calc(vectors.cells(i));
double score = _dist_fun->to_rawscore(distance);
result = std::max(result, score);
}
@@ -59,7 +60,7 @@ public:
auto vectors = _attr_tensor.get_vectors(docid);
double result = std::numeric_limits<double>::max();
for (uint32_t i = 0; i < vectors.subspaces(); ++i) {
- double distance = _dist_fun->calc_with_limit(_query_tensor_cells, vectors.cells(i), limit);
+ double distance = _dist_fun->calc_with_limit(vectors.cells(i), limit);
result = std::min(result, distance);
}
return result;
@@ -67,7 +68,7 @@ public:
void calc_closest_subspace(VectorBundle vectors, std::optional<uint32_t>& closest_subspace, double& best_distance) {
for (uint32_t i = 0; i < vectors.subspaces(); ++i) {
- double distance = _dist_fun->calc(_query_tensor_cells, vectors.cells(i));
+ double distance = _dist_fun->calc(vectors.cells(i));
if (!closest_subspace.has_value() || distance < best_distance) {
best_distance = distance;
closest_subspace = i;
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function.h b/searchlib/src/vespa/searchlib/tensor/distance_function.h
index d5ebf656189..443191a272c 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function.h
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function.h
@@ -9,13 +9,24 @@ namespace vespalib::eval { struct TypedCells; }
namespace search::tensor {
+class DistanceConverter {
+public:
+ virtual ~DistanceConverter() = default;
+
+ // convert threshold (external distance units) to internal units
+ virtual double convert_threshold(double threshold) const = 0;
+
+ // convert internal distance to rawscore (1.0 / (1.0 + d))
+ virtual double to_rawscore(double distance) const = 0;
+};
+
/**
* Interface used to calculate the distance between two n-dimensional vectors.
*
* The vectors must be of same size and same cell type (float or double).
* The actual implementation must know which type the vectors are.
*/
-class DistanceFunction {
+class DistanceFunction : public DistanceConverter {
private:
vespalib::eval::CellType _expect_cell_type;
public:
@@ -33,12 +44,6 @@ public:
// calculate internal distance (comparable)
virtual double calc(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs) const = 0;
- // convert threshold (external distance units) to internal units
- virtual double convert_threshold(double threshold) const = 0;
-
- // convert internal distance to rawscore (1.0 / (1.0 + d))
- virtual double to_rawscore(double distance) const = 0;
-
// calculate internal distance, early return allowed if > limit
virtual double calc_with_limit(const vespalib::eval::TypedCells& lhs,
const vespalib::eval::TypedCells& rhs,
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
index 96dfc580d87..f96715bcf60 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
@@ -3,6 +3,8 @@
#include "distance_function_factory.h"
#include "distance_functions.h"
#include <vespa/vespalib/util/typify.h>
+#include <vespa/vespalib/util/array.h>
+#include <vespa/vespalib/util/arrayref.h>
#include <vespa/log/log.h>
LOG_SETUP(".searchlib.tensor.distance_function_factory");
@@ -21,9 +23,9 @@ make_distance_function(DistanceMetric variant, CellType cell_type)
switch (cell_type) {
case CellType::FLOAT: return std::make_unique<SquaredEuclideanDistanceHW<float>>();
case CellType::DOUBLE: return std::make_unique<SquaredEuclideanDistanceHW<double>>();
- case CellType::INT8: return std::make_unique<SquaredEuclideanDistanceHW<vespalib::eval::Int8Float>>();
+ case CellType::INT8: return std::make_unique<SquaredEuclideanDistanceHW<vespalib::eval::Int8Float>>();
default: return std::make_unique<SquaredEuclideanDistance>(CellType::FLOAT);
- }
+ }
case DistanceMetric::Angular:
switch (cell_type) {
case CellType::FLOAT: return std::make_unique<AngularDistanceHW<float>>();
@@ -45,4 +47,54 @@ make_distance_function(DistanceMetric variant, CellType cell_type)
return DistanceFunction::UP();
}
+
+class SimpleBoundDistanceFunction : public BoundDistanceFunction {
+ const vespalib::eval::TypedCells _lhs;
+ const DistanceFunction &_df;
+public:
+ SimpleBoundDistanceFunction(const vespalib::eval::TypedCells& lhs,
+ const DistanceFunction &df)
+ : BoundDistanceFunction(lhs.type),
+ _lhs(lhs),
+ _df(df)
+ {}
+
+ double calc(const vespalib::eval::TypedCells& rhs) const override {
+ return _df.calc(_lhs, rhs);
+ }
+ double convert_threshold(double threshold) const override {
+ return _df.convert_threshold(threshold);
+ }
+ double to_rawscore(double distance) const override {
+ return _df.to_rawscore(distance);
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override {
+ return _df.calc_with_limit(_lhs, rhs, limit);
+ }
+};
+
+class SimpleDistanceFunctionFactory : public DistanceFunctionFactory {
+ DistanceFunction::UP _df;
+public:
+ SimpleDistanceFunctionFactory(DistanceFunction::UP df)
+ : DistanceFunctionFactory(df->expected_cell_type()),
+ _df(std::move(df))
+ {}
+
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override {
+ return std::make_unique<SimpleBoundDistanceFunction>(lhs, *_df);
+ }
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override {
+ return std::make_unique<SimpleBoundDistanceFunction>(lhs, *_df);
+ }
+};
+
+std::unique_ptr<DistanceFunctionFactory>
+make_distance_function_factory(search::attribute::DistanceMetric variant,
+ vespalib::eval::CellType cell_type)
+{
+ auto df = make_distance_function(variant, cell_type);
+ return std::make_unique<SimpleDistanceFunctionFactory>(std::move(df));
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h
index 2d7eb4e73c1..1edb94bd7aa 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h
@@ -3,12 +3,27 @@
#pragma once
#include "distance_function.h"
+#include "bound_distance_function.h"
#include <vespa/eval/eval/value_type.h>
#include <vespa/searchcommon/attribute/distance_metric.h>
namespace search::tensor {
/**
+ * API for binding the LHS of a distance calculation
+ * This allows keeping global state in the factory itself, and state
+ * for one particular vector in the distance function object.
+ */
+struct DistanceFunctionFactory {
+ const vespalib::eval::CellType expected_cell_type;
+ DistanceFunctionFactory(vespalib::eval::CellType ct) : expected_cell_type(ct) {}
+ virtual ~DistanceFunctionFactory() {}
+ virtual BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) = 0;
+ virtual BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) = 0;
+ using UP = std::unique_ptr<DistanceFunctionFactory>;
+};
+
+/**
* Create a distance function object customized for the given metric
* variant and cell type.
**/
@@ -16,4 +31,8 @@ DistanceFunction::UP
make_distance_function(search::attribute::DistanceMetric variant,
vespalib::eval::CellType cell_type);
+DistanceFunctionFactory::UP
+make_distance_function_factory(search::attribute::DistanceMetric variant,
+ vespalib::eval::CellType cell_type);
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index af332189b61..fa7f150fd89 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -278,23 +278,25 @@ double
HnswIndex<type>::calc_distance(uint32_t lhs_nodeid, uint32_t rhs_nodeid) const
{
auto lhs = get_vector(lhs_nodeid);
- return calc_distance(lhs, rhs_nodeid);
+ auto df = _distance_ff->for_insertion_vector(lhs);
+ auto rhs = get_vector(rhs_nodeid);
+ return df->calc(rhs);
}
template <HnswIndexType type>
double
-HnswIndex<type>::calc_distance(const TypedCells& lhs, uint32_t rhs_nodeid) const
+HnswIndex<type>::calc_distance(const BoundDistanceFunction &df, uint32_t rhs_nodeid) const
{
auto rhs = get_vector(rhs_nodeid);
- return _distance_func->calc(lhs, rhs);
+ return df.calc(rhs);
}
template <HnswIndexType type>
double
-HnswIndex<type>::calc_distance(const TypedCells& lhs, uint32_t rhs_docid, uint32_t rhs_subspace) const
+HnswIndex<type>::calc_distance(const BoundDistanceFunction &df, uint32_t rhs_docid, uint32_t rhs_subspace) const
{
auto rhs = get_vector(rhs_docid, rhs_subspace);
- return _distance_func->calc(lhs, rhs);
+ return df.calc(rhs);
}
template <HnswIndexType type>
@@ -323,7 +325,9 @@ HnswIndex<type>::estimate_visited_nodes(uint32_t level, uint32_t nodeid_limit, u
template <HnswIndexType type>
HnswCandidate
-HnswIndex<type>::find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level) const
+HnswIndex<type>::find_nearest_in_layer(
+ const BoundDistanceFunction &df,
+ const HnswCandidate& entry_point, uint32_t level) const
{
HnswCandidate nearest = entry_point;
bool keep_searching = true;
@@ -334,7 +338,7 @@ HnswIndex<type>::find_nearest_in_layer(const TypedCells& input, const HnswCandid
auto neighbor_ref = neighbor_node.levels_ref().load_acquire();
uint32_t neighbor_docid = acquire_docid(neighbor_node, neighbor_nodeid);
uint32_t neighbor_subspace = neighbor_node.acquire_subspace();
- double dist = calc_distance(input, neighbor_docid, neighbor_subspace);
+ double dist = calc_distance(df, neighbor_docid, neighbor_subspace);
if (_graph.still_valid(neighbor_nodeid, neighbor_ref)
&& dist < nearest.distance)
{
@@ -349,9 +353,11 @@ HnswIndex<type>::find_nearest_in_layer(const TypedCells& input, const HnswCandid
template <HnswIndexType type>
template <class VisitedTracker, class BestNeighbors>
void
-HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors_to_find,
- BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter,
- uint32_t nodeid_limit, uint32_t estimated_visited_nodes) const
+HnswIndex<type>::search_layer_helper(
+ const BoundDistanceFunction &df,
+ uint32_t neighbors_to_find,
+ BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter,
+ uint32_t nodeid_limit, uint32_t estimated_visited_nodes) const
{
NearestPriQ candidates;
GlobalFilterWrapper<type> filter_wrapper(filter);
@@ -389,7 +395,7 @@ HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors
}
uint32_t neighbor_docid = acquire_docid(neighbor_node, neighbor_nodeid);
uint32_t neighbor_subspace = neighbor_node.acquire_subspace();
- double dist_to_input = calc_distance(input, neighbor_docid, neighbor_subspace);
+ double dist_to_input = calc_distance(df, neighbor_docid, neighbor_subspace);
if (dist_to_input < limit_dist) {
candidates.emplace(neighbor_nodeid, neighbor_ref, dist_to_input);
if (filter_wrapper.check(neighbor_docid)) {
@@ -407,29 +413,31 @@ HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors
template <HnswIndexType type>
template <class BestNeighbors>
void
-HnswIndex<type>::search_layer(const TypedCells& input, uint32_t neighbors_to_find,
- BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter) const
+HnswIndex<type>::search_layer(
+ const BoundDistanceFunction &df,
+ uint32_t neighbors_to_find,
+ BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter) const
{
uint32_t nodeid_limit = _graph.nodes_size.load(std::memory_order_acquire);
uint32_t estimated_visited_nodes = estimate_visited_nodes(level, nodeid_limit, neighbors_to_find, filter);
if (estimated_visited_nodes >= nodeid_limit / 128) {
- search_layer_helper<BitVectorVisitedTracker>(input, neighbors_to_find, best_neighbors, level, filter, nodeid_limit, estimated_visited_nodes);
+ search_layer_helper<BitVectorVisitedTracker>(df, neighbors_to_find, best_neighbors, level, filter, nodeid_limit, estimated_visited_nodes);
} else {
- search_layer_helper<HashSetVisitedTracker>(input, neighbors_to_find, best_neighbors, level, filter, nodeid_limit, estimated_visited_nodes);
+ search_layer_helper<HashSetVisitedTracker>(df, neighbors_to_find, best_neighbors, level, filter, nodeid_limit, estimated_visited_nodes);
}
}
template <HnswIndexType type>
-HnswIndex<type>::HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func,
+HnswIndex<type>::HnswIndex(const DocVectorAccess& vectors, DistanceFunctionFactory::UP distance_ff,
RandomLevelGenerator::UP level_generator, const HnswIndexConfig& cfg)
: _graph(),
_vectors(vectors),
- _distance_func(std::move(distance_func)),
+ _distance_ff(std::move(distance_ff)),
_level_generator(std::move(level_generator)),
_id_mapping(),
_cfg(cfg)
{
- assert(_distance_func);
+ assert(_distance_ff);
}
template <HnswIndexType type>
@@ -483,12 +491,13 @@ HnswIndex<type>::internal_prepare_add_node(PreparedAddDoc& op, TypedCells input_
return;
}
int search_level = entry.level;
- double entry_dist = calc_distance(input_vector, entry.nodeid);
+ auto df = _distance_ff->for_insertion_vector(input_vector);
+ double entry_dist = calc_distance(*df, entry.nodeid);
uint32_t entry_docid = get_docid(entry.nodeid);
// TODO: check if entry nodeid/levels_ref is still valid here
HnswCandidate entry_point(entry.nodeid, entry_docid, entry.levels_ref, entry_dist);
while (search_level > node_max_level) {
- entry_point = find_nearest_in_layer(input_vector, entry_point, search_level);
+ entry_point = find_nearest_in_layer(*df, entry_point, search_level);
--search_level;
}
@@ -497,7 +506,7 @@ HnswIndex<type>::internal_prepare_add_node(PreparedAddDoc& op, TypedCells input_
search_level = std::min(node_max_level, search_level);
// Find neighbors of the added document in each level it should exist in.
while (search_level >= 0) {
- search_layer(input_vector, _cfg.neighbors_to_explore_at_construction(), best_neighbors, search_level);
+ search_layer(*df, _cfg.neighbors_to_explore_at_construction(), best_neighbors, search_level);
auto neighbors = select_neighbors(best_neighbors.peek(), _cfg.max_links_on_inserts());
auto& links = connections[search_level];
links.reserve(neighbors.used.size());
@@ -850,11 +859,13 @@ struct NeighborsByDocId {
template <HnswIndexType type>
std::vector<NearestNeighborIndex::Neighbor>
-HnswIndex<type>::top_k_by_docid(uint32_t k, TypedCells vector,
- const GlobalFilter *filter, uint32_t explore_k,
- double distance_threshold) const
+HnswIndex<type>::top_k_by_docid(
+ uint32_t k,
+ const BoundDistanceFunction &df,
+ const GlobalFilter *filter, uint32_t explore_k,
+ double distance_threshold) const
{
- SearchBestNeighbors candidates = top_k_candidates(vector, std::max(k, explore_k), filter);
+ SearchBestNeighbors candidates = top_k_candidates(df, std::max(k, explore_k), filter);
auto result = candidates.get_neighbors(k, distance_threshold);
std::sort(result.begin(), result.end(), NeighborsByDocId());
return result;
@@ -862,24 +873,31 @@ HnswIndex<type>::top_k_by_docid(uint32_t k, TypedCells vector,
template <HnswIndexType type>
std::vector<NearestNeighborIndex::Neighbor>
-HnswIndex<type>::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k,
- double distance_threshold) const
+HnswIndex<type>::find_top_k(
+ uint32_t k,
+ const BoundDistanceFunction &df,
+ uint32_t explore_k,
+ double distance_threshold) const
{
- return top_k_by_docid(k, vector, nullptr, explore_k, distance_threshold);
+ return top_k_by_docid(k, df, nullptr, explore_k, distance_threshold);
}
template <HnswIndexType type>
std::vector<NearestNeighborIndex::Neighbor>
-HnswIndex<type>::find_top_k_with_filter(uint32_t k, TypedCells vector,
- const GlobalFilter &filter, uint32_t explore_k,
- double distance_threshold) const
+HnswIndex<type>::find_top_k_with_filter(
+ uint32_t k,
+ const BoundDistanceFunction &df,
+ const GlobalFilter &filter, uint32_t explore_k,
+ double distance_threshold) const
{
- return top_k_by_docid(k, vector, &filter, explore_k, distance_threshold);
+ return top_k_by_docid(k, df, &filter, explore_k, distance_threshold);
}
template <HnswIndexType type>
typename HnswIndex<type>::SearchBestNeighbors
-HnswIndex<type>::top_k_candidates(const TypedCells &vector, uint32_t k, const GlobalFilter *filter) const
+HnswIndex<type>::top_k_candidates(
+ const BoundDistanceFunction &df,
+ uint32_t k, const GlobalFilter *filter) const
{
SearchBestNeighbors best_neighbors;
auto entry = _graph.get_entry_node();
@@ -888,16 +906,16 @@ HnswIndex<type>::top_k_candidates(const TypedCells &vector, uint32_t k, const Gl
return best_neighbors;
}
int search_level = entry.level;
- double entry_dist = calc_distance(vector, entry.nodeid);
+ double entry_dist = calc_distance(df, entry.nodeid);
uint32_t entry_docid = get_docid(entry.nodeid);
// TODO: check if entry docid/levels_ref is still valid here
HnswCandidate entry_point(entry.nodeid, entry_docid, entry.levels_ref, entry_dist);
while (search_level > 0) {
- entry_point = find_nearest_in_layer(vector, entry_point, search_level);
+ entry_point = find_nearest_in_layer(df, entry_point, search_level);
--search_level;
}
best_neighbors.push(entry_point);
- search_layer(vector, k, best_neighbors, 0, filter);
+ search_layer(df, k, best_neighbors, 0, filter);
return best_neighbors;
}
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
index 984acc6c9a1..0809dcf4fe3 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
@@ -4,6 +4,7 @@
#include "hnsw_index_config.h"
#include "distance_function.h"
+#include "distance_function_factory.h"
#include "doc_vector_access.h"
#include "hnsw_identity_mapping.h"
#include "hnsw_index_utils.h"
@@ -104,7 +105,7 @@ protected:
GraphType _graph;
const DocVectorAccess& _vectors;
- DistanceFunction::UP _distance_func;
+ std::unique_ptr<DistanceFunctionFactory> _distance_ff;
RandomLevelGenerator::UP _level_generator;
IdMapping _id_mapping; // mapping from docid to nodeid vector
HnswIndexConfig _cfg;
@@ -158,23 +159,23 @@ protected:
}
double calc_distance(uint32_t lhs_nodeid, uint32_t rhs_nodeid) const;
- double calc_distance(const TypedCells& lhs, uint32_t rhs_nodeid) const;
- double calc_distance(const TypedCells& lhs, uint32_t rhs_docid, uint32_t rhs_subspace) const;
+ double calc_distance(const BoundDistanceFunction &df, uint32_t rhs_nodeid) const;
+ double calc_distance(const BoundDistanceFunction &df, uint32_t rhs_docid, uint32_t rhs_subspace) const;
uint32_t estimate_visited_nodes(uint32_t level, uint32_t nodeid_limit, uint32_t neighbors_to_find, const GlobalFilter* filter) const;
/**
* Performs a greedy search in the given layer to find the candidate that is nearest the input vector.
*/
- HnswCandidate find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level) const;
+ HnswCandidate find_nearest_in_layer(const BoundDistanceFunction &df, const HnswCandidate& entry_point, uint32_t level) const;
template <class VisitedTracker, class BestNeighbors>
- void search_layer_helper(const TypedCells& input, uint32_t neighbors_to_find, BestNeighbors& best_neighbors,
+ void search_layer_helper(const BoundDistanceFunction &df, uint32_t neighbors_to_find, BestNeighbors& best_neighbors,
uint32_t level, const GlobalFilter *filter,
uint32_t nodeid_limit,
uint32_t estimated_visited_nodes) const;
template <class BestNeighbors>
- void search_layer(const TypedCells& input, uint32_t neighbors_to_find, BestNeighbors& best_neighbors,
+ void search_layer(const BoundDistanceFunction &df, uint32_t neighbors_to_find, BestNeighbors& best_neighbors,
uint32_t level, const GlobalFilter *filter = nullptr) const;
- std::vector<Neighbor> top_k_by_docid(uint32_t k, TypedCells vector,
+ std::vector<Neighbor> top_k_by_docid(uint32_t k, const BoundDistanceFunction &df,
const GlobalFilter *filter, uint32_t explore_k,
double distance_threshold) const;
@@ -185,7 +186,7 @@ protected:
void internal_complete_add(uint32_t docid, internal::PreparedAddDoc &op);
void internal_complete_add_node(uint32_t nodeid, uint32_t docid, uint32_t subspace, internal::PreparedAddNode &prepared_node);
public:
- HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func,
+ HnswIndex(const DocVectorAccess& vectors, DistanceFunctionFactory::UP distance_ff,
RandomLevelGenerator::UP level_generator, const HnswIndexConfig& cfg);
~HnswIndex() override;
@@ -213,14 +214,23 @@ public:
std::unique_ptr<NearestNeighborIndexSaver> make_saver() const override;
std::unique_ptr<NearestNeighborIndexLoader> make_loader(FastOS_FileInterface& file) override;
- std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k,
- double distance_threshold) const override;
- std::vector<Neighbor> find_top_k_with_filter(uint32_t k, TypedCells vector,
- const GlobalFilter &filter, uint32_t explore_k,
- double distance_threshold) const override;
- const DistanceFunction *distance_function() const override { return _distance_func.get(); }
+ std::vector<Neighbor> find_top_k(
+ uint32_t k,
+ const BoundDistanceFunction &df,
+ uint32_t explore_k,
+ double distance_threshold) const override;
- SearchBestNeighbors top_k_candidates(const TypedCells &vector, uint32_t k, const GlobalFilter *filter) const;
+ std::vector<Neighbor> find_top_k_with_filter(
+ uint32_t k,
+ const BoundDistanceFunction &df,
+ const GlobalFilter &filter, uint32_t explore_k,
+ double distance_threshold) const override;
+
+ DistanceFunctionFactory &distance_function_factory() const override { return *_distance_ff; }
+
+ SearchBestNeighbors top_k_candidates(
+ const BoundDistanceFunction &df,
+ uint32_t k, const GlobalFilter *filter) const;
uint32_t get_entry_nodeid() const { return _graph.get_entry_node().nodeid; }
int32_t get_entry_level() const { return _graph.get_entry_node().level; }
diff --git a/searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h
index ec6774c9517..b734663a6f4 100644
--- a/searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h
+++ b/searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h
@@ -12,6 +12,7 @@ namespace vespalib::slime { struct Inserter; }
namespace search::tensor {
+struct DistanceFunctionFactory;
class NearestNeighborIndex;
class SerializedTensorRef;
@@ -32,6 +33,7 @@ public:
virtual const vespalib::eval::ValueType & getTensorType() const = 0;
+ virtual DistanceFunctionFactory& distance_function_factory() const = 0;
virtual const NearestNeighborIndex* nearest_neighbor_index() const { return nullptr; }
using DistanceMetric = search::attribute::DistanceMetric;
virtual DistanceMetric distance_metric() const = 0;
diff --git a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h
index 4e1cc9efd96..0fb0fd1bf78 100644
--- a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h
+++ b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h
@@ -38,6 +38,9 @@ public:
SerializedTensorRef get_serialized_tensor_ref(uint32_t docid) const override;
bool supports_extract_cells_ref() const override { return _target_tensor_attribute.supports_extract_cells_ref(); }
bool supports_get_tensor_ref() const override { return _target_tensor_attribute.supports_get_tensor_ref(); }
+ DistanceFunctionFactory& distance_function_factory() const override {
+ return _target_tensor_attribute.distance_function_factory();
+ }
DistanceMetric distance_metric() const override { return _target_tensor_attribute.distance_metric(); }
bool supports_get_serialized_tensor_ref() const override;
uint32_t get_num_docs() const override { return getNumDocs(); }
diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
index 3f6c9b82a65..4b7b934fee0 100644
--- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
@@ -3,6 +3,7 @@
#pragma once
#include "distance_function.h"
+#include "distance_function_factory.h"
#include "prepare_result.h"
#include "vector_bundle.h"
#include <vespa/vespalib/util/generationhandler.h>
@@ -97,18 +98,18 @@ public:
virtual std::unique_ptr<NearestNeighborIndexLoader> make_loader(FastOS_FileInterface& file) = 0;
virtual std::vector<Neighbor> find_top_k(uint32_t k,
- vespalib::eval::TypedCells vector,
+ const BoundDistanceFunction &df,
uint32_t explore_k,
double distance_threshold) const = 0;
// only return neighbors where the corresponding filter bit is set
virtual std::vector<Neighbor> find_top_k_with_filter(uint32_t k,
- vespalib::eval::TypedCells vector,
+ const BoundDistanceFunction &df,
const GlobalFilter &filter,
uint32_t explore_k,
double distance_threshold) const = 0;
- virtual const DistanceFunction *distance_function() const = 0;
+ virtual DistanceFunctionFactory &distance_function_factory() const = 0;
};
}
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp
index 1e388199ef8..5e554f76779 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp
@@ -58,6 +58,7 @@ TensorAttribute::TensorAttribute(vespalib::stringref name, const Config &cfg, Te
: NotImplementedAttribute(name, cfg),
_refVector(cfg.getGrowStrategy(), getGenerationHolder()),
_tensorStore(tensorStore),
+ _distance_function_factory(make_distance_function_factory(cfg.distance_metric(), cfg.tensorType().cell_type())),
_index(),
_is_dense(cfg.tensorType().is_dense()),
_emptyTensor(createEmptyTensor(cfg.tensorType())),
@@ -280,6 +281,13 @@ TensorAttribute::getTensorType() const
return getConfig().tensorType();
}
+DistanceFunctionFactory&
+TensorAttribute::distance_function_factory() const
+{
+ return *_distance_function_factory;
+
+}
+
const NearestNeighborIndex*
TensorAttribute::nearest_neighbor_index() const
{
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h
index 4cb903c6c67..f629562a34d 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h
@@ -28,6 +28,7 @@ protected:
RefVector _refVector; // docId -> ref in data store for serialized tensor
TensorStore &_tensorStore; // data store for serialized tensors
+ std::unique_ptr<DistanceFunctionFactory> _distance_function_factory;
std::unique_ptr<NearestNeighborIndex> _index;
bool _is_dense;
std::unique_ptr<vespalib::eval::Value> _emptyTensor;
@@ -67,6 +68,7 @@ public:
bool supports_get_tensor_ref() const override { return false; }
bool supports_get_serialized_tensor_ref() const override;
const vespalib::eval::ValueType & getTensorType() const override;
+ DistanceFunctionFactory& distance_function_factory() const override;
const NearestNeighborIndex* nearest_neighbor_index() const override;
void get_state(const vespalib::slime::Inserter& inserter) const override;
void clearDocs(DocId lidLow, DocId lidLimit, bool in_shrink_lid_space) override;
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp
index 19c8cf6053b..f474d65a19d 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp
@@ -36,6 +36,7 @@ TensorExtAttribute::TensorExtAttribute(const vespalib::string& name, const Confi
: NotImplementedAttribute(name, cfg),
ITensorAttribute(),
IExtendAttribute(),
+ _distance_function_factory(make_distance_function_factory(cfg.distance_metric(), cfg.tensorType().cell_type())),
_subspace_type(cfg.tensorType()),
_empty(_subspace_type),
_empty_tensor(create_empty_tensor(cfg.tensorType()))
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h
index a58426cd146..93d7a94c257 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h
@@ -5,6 +5,7 @@
#include "i_tensor_attribute.h"
#include "empty_subspace.h"
#include "subspace_type.h"
+#include "distance_function_factory.h"
#include <vespa/searchlib/attribute/not_implemented_attribute.h>
#include <vespa/vespalib/stllike/allocator.h>
@@ -20,9 +21,12 @@ class TensorExtAttribute : public NotImplementedAttribute,
public IExtendAttribute
{
std::vector<const vespalib::eval::Value*> _data;
+ // XXX this should probably be longer-lived:
+ std::unique_ptr<DistanceFunctionFactory> _distance_function_factory;
SubspaceType _subspace_type;
EmptySubspace _empty;
std::unique_ptr<vespalib::eval::Value> _empty_tensor;
+
public:
TensorExtAttribute(const vespalib::string& name, const Config& cfg);
~TensorExtAttribute() override;
@@ -46,6 +50,9 @@ public:
bool supports_get_tensor_ref() const override;
bool supports_get_serialized_tensor_ref() const override;
const vespalib::eval::ValueType & getTensorType() const override;
+ DistanceFunctionFactory& distance_function_factory() const override {
+ return *_distance_function_factory;
+ }
search::attribute::DistanceMetric distance_metric() const override;
uint32_t get_num_docs() const override;
void get_state(const vespalib::slime::Inserter& inserter) const override;