diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-04-20 09:09:33 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-20 09:09:33 +0200 |
commit | 3cde18008cff2d1c812ec86141f538b56d5248ab (patch) | |
tree | eaca23b97d05abf775dfaf68b495cae70107c450 /searchlib | |
parent | 4bf83d5e87a8896ce3b6a14fb0889a2891053bf1 (diff) | |
parent | 732e4c4be8bbc5a43e3adae5db222301e630bd8c (diff) |
Merge pull request #26783 from vespa-engine/arnej/refactor-with-bound-distance
add mimimal version of BoundDistanceFunction
Diffstat (limited to 'searchlib')
25 files changed, 311 insertions, 121 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index 28c50891225..e3c9e05073e 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -294,30 +294,33 @@ public: std::unique_ptr<NearestNeighborIndexLoader> make_loader(FastOS_FileInterface& file) override { return std::make_unique<MockIndexLoader>(_index_value, file); } - std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, uint32_t explore_k, + std::vector<Neighbor> find_top_k(uint32_t k, + const search::tensor::BoundDistanceFunction &df, + uint32_t explore_k, double distance_threshold) const override { (void) k; - (void) vector; + (void) df; (void) explore_k; (void) distance_threshold; return std::vector<Neighbor>(); } - std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector, + std::vector<Neighbor> find_top_k_with_filter(uint32_t k, + const search::tensor::BoundDistanceFunction &df, const GlobalFilter& filter, uint32_t explore_k, double distance_threshold) const override { (void) k; - (void) vector; + (void) df; (void) explore_k; (void) filter; (void) distance_threshold; return std::vector<Neighbor>(); } - const search::tensor::DistanceFunction *distance_function() const override { - static search::tensor::SquaredEuclideanDistance my_dist_fun(vespalib::eval::CellType::DOUBLE); - return &my_dist_fun; + search::tensor::DistanceFunctionFactory &distance_function_factory() const override { + static search::tensor::DistanceFunctionFactory::UP my_dist_fun = search::tensor::make_distance_function_factory(search::attribute::DistanceMetric::Euclidean, vespalib::eval::CellType::DOUBLE); + return *my_dist_fun; } }; diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 2801bf90080..fd07529795a 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -134,7 +134,9 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._attr); - DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); + + auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, qtv.cells().type); + DistanceCalculator dist_calc(attr, dff->for_query_vector(qtv.cells())); NearestNeighborDistanceHeap dh(2); dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold)); const GlobalFilter &filter = *env._global_filter; @@ -260,7 +262,8 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) { auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._attr); - DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); + auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, qtv.cells().type); + DistanceCalculator dist_calc(attr, dff->for_query_vector(qtv.cells())); NearestNeighborDistanceHeap dh(2); auto dummy_filter = GlobalFilter::create(); auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, *dummy_filter); @@ -333,7 +336,10 @@ TEST(NnsIndexIteratorTest, require_that_iterator_works_as_expected) { std::vector<NnsIndexIterator::Hit> hits{{2,4.0}, {3,9.0}, {5,1.0}, {8,16.0}, {9,36.0}}; auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); - auto search = NnsIndexIterator::create(tfmd, hits, *euclid_d); + auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, CellType::DOUBLE); + vespalib::eval::TypedCells dummy; + auto df = dff->for_query_vector(dummy); + auto search = NnsIndexIterator::create(tfmd, hits, *df); search->initFullRange(); expect_not_match(*search, 1, 2); expect_match(*search, 2); diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index d9230849699..9f6216f5867 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -105,10 +105,16 @@ public: ~HnswIndexTest() {} + auto dff() { + return search::tensor::make_distance_function_factory( + search::attribute::DistanceMetric::Euclidean, + vespalib::eval::CellType::FLOAT); + } + void init(bool heuristic_select_neighbors) { auto generator = std::make_unique<LevelGenerator>(); level_generator = generator.get(); - index = std::make_unique<IndexType>(vectors, std::make_unique<SquaredEuclideanDistance>(vespalib::eval::CellType::FLOAT), + index = std::make_unique<IndexType>(vectors, dff(), std::move(generator), HnswIndexConfig(5, 2, 10, 0, heuristic_select_neighbors)); } @@ -165,9 +171,10 @@ public: uint32_t explore_k = 100; vespalib::ArrayRef qv_ref(qv); vespalib::eval::TypedCells qv_cells(qv_ref); + auto df = index->distance_function_factory().for_query_vector(qv_cells); auto got_by_docid = (global_filter->is_active()) ? - index->find_top_k_with_filter(k, qv_cells, *global_filter, explore_k, 10000.0) : - index->find_top_k(k, qv_cells, explore_k, 10000.0); + index->find_top_k_with_filter(k, *df, *global_filter, explore_k, 10000.0) : + index->find_top_k(k, *df, explore_k, 10000.0); std::vector<uint32_t> act; act.reserve(got_by_docid.size()); for (auto& hit : got_by_docid) { @@ -178,7 +185,8 @@ public: void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid, 0); - auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); + auto df = index->distance_function_factory().for_query_vector(qv); + auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; for (const auto & hit : rv) { @@ -189,7 +197,7 @@ public: if (exp_hits.size() == k) { std::vector<uint32_t> expected_by_docid = exp_hits; std::sort(expected_by_docid.begin(), expected_by_docid.end()); - auto got_by_docid = index->find_top_k(k, qv, k, 100100.25); + auto got_by_docid = index->find_top_k(k, *df, k, 100100.25); for (idx = 0; idx < k; ++idx) { EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid); } @@ -198,15 +206,16 @@ public: } void check_with_distance_threshold(uint32_t docid) { auto qv = vectors.get_vector(docid, 0); + auto df = index->distance_function_factory().for_query_vector(qv); uint32_t k = 3; - auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); + auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); EXPECT_EQ(rv.size(), 3); EXPECT_LE(rv[0].distance, rv[1].distance); double thr = (rv[0].distance + rv[1].distance) * 0.5; auto got_by_docid = (global_filter->is_active()) - ? index->find_top_k_with_filter(k, qv, *global_filter, k, thr) - : index->find_top_k(k, qv, k, thr); + ? index->find_top_k_with_filter(k, *df, *global_filter, k, thr) + : index->find_top_k(k, *df, k, thr); EXPECT_EQ(got_by_docid.size(), 1); EXPECT_EQ(got_by_docid[0].docid, index->get_docid(rv[0].nodeid)); for (const auto & hit : got_by_docid) { diff --git a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp index ecf310798af..0dcd77ec392 100644 --- a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp @@ -261,9 +261,15 @@ public: ~Stressor() {} + auto dff() { + return search::tensor::make_distance_function_factory( + search::attribute::DistanceMetric::Euclidean, + vespalib::eval::CellType::FLOAT); + } + void init() { uint32_t m = 16; - index = std::make_unique<IndexType>(vectors, std::make_unique<FloatSqEuclideanDistance>(), + index = std::make_unique<IndexType>(vectors, dff(), std::make_unique<InvLogLevelGenerator>(m), HnswIndexConfig(2*m, m, 200, 10, true)); } diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index 7fdf5230325..7c307a1e35f 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -106,13 +106,13 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter, d void NearestNeighborBlueprint::perform_top_k(const search::tensor::NearestNeighborIndex* nns_index) { - auto lhs = _query_tensor.cells(); uint32_t k = _adjusted_target_hits; + const auto &df = _distance_calc->function(); if (_global_filter->is_active()) { - _found_hits = nns_index->find_top_k_with_filter(k, lhs, *_global_filter, k + _explore_additional_hits, _distance_threshold); + _found_hits = nns_index->find_top_k_with_filter(k, df, *_global_filter, k + _explore_additional_hits, _distance_threshold); _algorithm = Algorithm::INDEX_TOP_K_WITH_FILTER; } else { - _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold); + _found_hits = nns_index->find_top_k(k, df, k + _explore_additional_hits, _distance_threshold); _algorithm = Algorithm::INDEX_TOP_K; } } diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp index 95264a79431..5ec4357ca24 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp @@ -18,13 +18,13 @@ class NeighborVectorIterator : public NnsIndexIterator private: fef::TermFieldMatchData &_tfmd; const std::vector<Neighbor> &_hits; - const search::tensor::DistanceFunction &_dist_fun; + const search::tensor::BoundDistanceFunction &_dist_fun; uint32_t _idx; double _last_abstract_dist; public: NeighborVectorIterator(fef::TermFieldMatchData &tfmd, const std::vector<Neighbor> &hits, - const search::tensor::DistanceFunction &dist_fun) + const search::tensor::BoundDistanceFunction &dist_fun) : _tfmd(tfmd), _hits(hits), _dist_fun(dist_fun), @@ -65,7 +65,7 @@ std::unique_ptr<NnsIndexIterator> NnsIndexIterator::create( fef::TermFieldMatchData &tfmd, const std::vector<Neighbor> &hits, - const search::tensor::DistanceFunction &dist_fun) + const search::tensor::BoundDistanceFunction &dist_fun) { return std::make_unique<NeighborVectorIterator>(tfmd, hits, dist_fun); } diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h index 031a603de49..84ff0f04813 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h +++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h @@ -16,7 +16,7 @@ public: static std::unique_ptr<NnsIndexIterator> create( fef::TermFieldMatchData &tfmd, const std::vector<Hit> &hits, - const search::tensor::DistanceFunction &dist_fun); + const search::tensor::BoundDistanceFunction &dist_fun); }; } // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt index 313863d8dcb..090042e5b83 100644 --- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt @@ -3,6 +3,7 @@ vespa_add_library(searchlib_tensor OBJECT SOURCES angular_distance.cpp bitvector_visited_tracker.cpp + bound_distance_function.cpp default_nearest_neighbor_index_factory.cpp dense_tensor_attribute.cpp dense_tensor_store.cpp diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp new file mode 100644 index 00000000000..33b94e5218c --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp @@ -0,0 +1,3 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "bound_distance_function.h" diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h new file mode 100644 index 00000000000..17e9e49cada --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h @@ -0,0 +1,44 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <memory> +#include <vespa/eval/eval/cell_type.h> +#include <vespa/eval/eval/typed_cells.h> +#include <vespa/vespalib/util/array.h> +#include <vespa/vespalib/util/arrayref.h> +#include "distance_function.h" + +namespace vespalib::eval { struct TypedCells; } + +namespace search::tensor { + +/** + * Interface used to calculate the distance from a prebound n-dimensional vector. + * + * The actual implementation must know which type the vectors are. + */ +class BoundDistanceFunction : public DistanceConverter { +private: + vespalib::eval::CellType _expect_cell_type; +public: + using UP = std::unique_ptr<BoundDistanceFunction>; + + BoundDistanceFunction(vespalib::eval::CellType expected) : _expect_cell_type(expected) {} + + virtual ~BoundDistanceFunction() = default; + + // input vectors will be converted to this cell type: + vespalib::eval::CellType expected_cell_type() const { + return _expect_cell_type; + } + + // calculate internal distance (comparable) + virtual double calc(const vespalib::eval::TypedCells& rhs) const = 0; + + // calculate internal distance, early return allowed if > limit + virtual double calc_with_limit(const vespalib::eval::TypedCells& rhs, + double limit) const = 0; +}; + +} diff --git a/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp index 4f6f8ac5c87..77c912dc690 100644 --- a/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/default_nearest_neighbor_index_factory.cpp @@ -41,12 +41,12 @@ DefaultNearestNeighborIndexFactory::make(const DocVectorAccess& vectors, true); if (multi_vector_index) { return std::make_unique<HnswIndex<HnswIndexType::MULTI>>(vectors, - make_distance_function(params.distance_metric(), cell_type), + make_distance_function_factory(params.distance_metric(), cell_type), make_random_level_generator(m), cfg); } else { return std::make_unique<HnswIndex<HnswIndexType::SINGLE>>(vectors, - make_distance_function(params.distance_metric(), cell_type), + make_distance_function_factory(params.distance_metric(), cell_type), make_random_level_generator(m), cfg); } diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp index b669b5ffea6..8da777d97eb 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp @@ -47,8 +47,6 @@ struct ConvertCellsSelector } }; - - } namespace search::tensor { @@ -58,36 +56,27 @@ DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tens : _attr_tensor(attr_tensor), _query_tensor_uptr(), _query_tensor(&query_tensor_in), - _query_tensor_cells(), - _dist_fun_uptr(make_distance_function(_attr_tensor.distance_metric(), - _attr_tensor.getTensorType().cell_type())), - _dist_fun(_dist_fun_uptr.get()) + _dist_fun() { - assert(_dist_fun); - auto nns_index = _attr_tensor.nearest_neighbor_index(); - if (nns_index) { - _dist_fun = nns_index->distance_function(); - assert(_dist_fun); - } + auto * nns_index = _attr_tensor.nearest_neighbor_index(); + auto & dff = nns_index ? nns_index->distance_function_factory() : attr_tensor.distance_function_factory(); auto query_ct = _query_tensor->cells().type; - CellType required_ct = _dist_fun->expected_cell_type(); + CellType required_ct = dff.expected_cell_type; if (query_ct != required_ct) { ConvertCellsSelector converter; _query_tensor_uptr = converter(query_ct, required_ct, *_query_tensor); _query_tensor = _query_tensor_uptr.get(); } - _query_tensor_cells = _query_tensor->cells(); + _dist_fun = dff.for_query_vector(_query_tensor->cells()); + assert(_dist_fun); } DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tensor, - const vespalib::eval::Value& query_tensor_in, - const DistanceFunction& function_in) + BoundDistanceFunction::UP function_in) : _attr_tensor(attr_tensor), _query_tensor_uptr(), - _query_tensor(&query_tensor_in), - _query_tensor_cells(_query_tensor->cells()), - _dist_fun_uptr(), - _dist_fun(&function_in) + _query_tensor(nullptr), + _dist_fun(std::move(function_in)) { } diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h index 6b4cf142264..a3ca771e30c 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h +++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h @@ -2,6 +2,7 @@ #pragma once #include "distance_function.h" +#include "distance_function_factory.h" #include "i_tensor_attribute.h" #include "vector_bundle.h" #include <optional> @@ -23,9 +24,7 @@ private: const tensor::ITensorAttribute& _attr_tensor; std::unique_ptr<vespalib::eval::Value> _query_tensor_uptr; const vespalib::eval::Value* _query_tensor; - vespalib::eval::TypedCells _query_tensor_cells; - std::unique_ptr<DistanceFunction> _dist_fun_uptr; - const DistanceFunction* _dist_fun; + std::unique_ptr<BoundDistanceFunction> _dist_fun; public: DistanceCalculator(const tensor::ITensorAttribute& attr_tensor, @@ -35,20 +34,22 @@ public: * Only used by unit tests where ownership of query tensor and distance function is handled outside. */ DistanceCalculator(const tensor::ITensorAttribute& attr_tensor, - const vespalib::eval::Value& query_tensor_in, - const DistanceFunction& function_in); + BoundDistanceFunction::UP function_in); ~DistanceCalculator(); const tensor::ITensorAttribute& attribute_tensor() const { return _attr_tensor; } - const vespalib::eval::Value& query_tensor() const { return *_query_tensor; } - const DistanceFunction& function() const { return *_dist_fun; } + const vespalib::eval::Value& query_tensor() const { + assert(_query_tensor != nullptr); + return *_query_tensor; + } + const BoundDistanceFunction& function() const { return *_dist_fun; } double calc_raw_score(uint32_t docid) const { auto vectors = _attr_tensor.get_vectors(docid); double result = 0.0; for (uint32_t i = 0; i < vectors.subspaces(); ++i) { - double distance = _dist_fun->calc(_query_tensor_cells, vectors.cells(i)); + double distance = _dist_fun->calc(vectors.cells(i)); double score = _dist_fun->to_rawscore(distance); result = std::max(result, score); } @@ -59,7 +60,7 @@ public: auto vectors = _attr_tensor.get_vectors(docid); double result = std::numeric_limits<double>::max(); for (uint32_t i = 0; i < vectors.subspaces(); ++i) { - double distance = _dist_fun->calc_with_limit(_query_tensor_cells, vectors.cells(i), limit); + double distance = _dist_fun->calc_with_limit(vectors.cells(i), limit); result = std::min(result, distance); } return result; @@ -67,7 +68,7 @@ public: void calc_closest_subspace(VectorBundle vectors, std::optional<uint32_t>& closest_subspace, double& best_distance) { for (uint32_t i = 0; i < vectors.subspaces(); ++i) { - double distance = _dist_fun->calc(_query_tensor_cells, vectors.cells(i)); + double distance = _dist_fun->calc(vectors.cells(i)); if (!closest_subspace.has_value() || distance < best_distance) { best_distance = distance; closest_subspace = i; diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function.h b/searchlib/src/vespa/searchlib/tensor/distance_function.h index d5ebf656189..443191a272c 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function.h +++ b/searchlib/src/vespa/searchlib/tensor/distance_function.h @@ -9,13 +9,24 @@ namespace vespalib::eval { struct TypedCells; } namespace search::tensor { +class DistanceConverter { +public: + virtual ~DistanceConverter() = default; + + // convert threshold (external distance units) to internal units + virtual double convert_threshold(double threshold) const = 0; + + // convert internal distance to rawscore (1.0 / (1.0 + d)) + virtual double to_rawscore(double distance) const = 0; +}; + /** * Interface used to calculate the distance between two n-dimensional vectors. * * The vectors must be of same size and same cell type (float or double). * The actual implementation must know which type the vectors are. */ -class DistanceFunction { +class DistanceFunction : public DistanceConverter { private: vespalib::eval::CellType _expect_cell_type; public: @@ -33,12 +44,6 @@ public: // calculate internal distance (comparable) virtual double calc(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs) const = 0; - // convert threshold (external distance units) to internal units - virtual double convert_threshold(double threshold) const = 0; - - // convert internal distance to rawscore (1.0 / (1.0 + d)) - virtual double to_rawscore(double distance) const = 0; - // calculate internal distance, early return allowed if > limit virtual double calc_with_limit(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs, diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp index 96dfc580d87..f96715bcf60 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp @@ -3,6 +3,8 @@ #include "distance_function_factory.h" #include "distance_functions.h" #include <vespa/vespalib/util/typify.h> +#include <vespa/vespalib/util/array.h> +#include <vespa/vespalib/util/arrayref.h> #include <vespa/log/log.h> LOG_SETUP(".searchlib.tensor.distance_function_factory"); @@ -21,9 +23,9 @@ make_distance_function(DistanceMetric variant, CellType cell_type) switch (cell_type) { case CellType::FLOAT: return std::make_unique<SquaredEuclideanDistanceHW<float>>(); case CellType::DOUBLE: return std::make_unique<SquaredEuclideanDistanceHW<double>>(); - case CellType::INT8: return std::make_unique<SquaredEuclideanDistanceHW<vespalib::eval::Int8Float>>(); + case CellType::INT8: return std::make_unique<SquaredEuclideanDistanceHW<vespalib::eval::Int8Float>>(); default: return std::make_unique<SquaredEuclideanDistance>(CellType::FLOAT); - } + } case DistanceMetric::Angular: switch (cell_type) { case CellType::FLOAT: return std::make_unique<AngularDistanceHW<float>>(); @@ -45,4 +47,54 @@ make_distance_function(DistanceMetric variant, CellType cell_type) return DistanceFunction::UP(); } + +class SimpleBoundDistanceFunction : public BoundDistanceFunction { + const vespalib::eval::TypedCells _lhs; + const DistanceFunction &_df; +public: + SimpleBoundDistanceFunction(const vespalib::eval::TypedCells& lhs, + const DistanceFunction &df) + : BoundDistanceFunction(lhs.type), + _lhs(lhs), + _df(df) + {} + + double calc(const vespalib::eval::TypedCells& rhs) const override { + return _df.calc(_lhs, rhs); + } + double convert_threshold(double threshold) const override { + return _df.convert_threshold(threshold); + } + double to_rawscore(double distance) const override { + return _df.to_rawscore(distance); + } + double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override { + return _df.calc_with_limit(_lhs, rhs, limit); + } +}; + +class SimpleDistanceFunctionFactory : public DistanceFunctionFactory { + DistanceFunction::UP _df; +public: + SimpleDistanceFunctionFactory(DistanceFunction::UP df) + : DistanceFunctionFactory(df->expected_cell_type()), + _df(std::move(df)) + {} + + BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override { + return std::make_unique<SimpleBoundDistanceFunction>(lhs, *_df); + } + BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override { + return std::make_unique<SimpleBoundDistanceFunction>(lhs, *_df); + } +}; + +std::unique_ptr<DistanceFunctionFactory> +make_distance_function_factory(search::attribute::DistanceMetric variant, + vespalib::eval::CellType cell_type) +{ + auto df = make_distance_function(variant, cell_type); + return std::make_unique<SimpleDistanceFunctionFactory>(std::move(df)); +} + } diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h index 2d7eb4e73c1..1edb94bd7aa 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h @@ -3,12 +3,27 @@ #pragma once #include "distance_function.h" +#include "bound_distance_function.h" #include <vespa/eval/eval/value_type.h> #include <vespa/searchcommon/attribute/distance_metric.h> namespace search::tensor { /** + * API for binding the LHS of a distance calculation + * This allows keeping global state in the factory itself, and state + * for one particular vector in the distance function object. + */ +struct DistanceFunctionFactory { + const vespalib::eval::CellType expected_cell_type; + DistanceFunctionFactory(vespalib::eval::CellType ct) : expected_cell_type(ct) {} + virtual ~DistanceFunctionFactory() {} + virtual BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) = 0; + virtual BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) = 0; + using UP = std::unique_ptr<DistanceFunctionFactory>; +}; + +/** * Create a distance function object customized for the given metric * variant and cell type. **/ @@ -16,4 +31,8 @@ DistanceFunction::UP make_distance_function(search::attribute::DistanceMetric variant, vespalib::eval::CellType cell_type); +DistanceFunctionFactory::UP +make_distance_function_factory(search::attribute::DistanceMetric variant, + vespalib::eval::CellType cell_type); + } diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index af332189b61..fa7f150fd89 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -278,23 +278,25 @@ double HnswIndex<type>::calc_distance(uint32_t lhs_nodeid, uint32_t rhs_nodeid) const { auto lhs = get_vector(lhs_nodeid); - return calc_distance(lhs, rhs_nodeid); + auto df = _distance_ff->for_insertion_vector(lhs); + auto rhs = get_vector(rhs_nodeid); + return df->calc(rhs); } template <HnswIndexType type> double -HnswIndex<type>::calc_distance(const TypedCells& lhs, uint32_t rhs_nodeid) const +HnswIndex<type>::calc_distance(const BoundDistanceFunction &df, uint32_t rhs_nodeid) const { auto rhs = get_vector(rhs_nodeid); - return _distance_func->calc(lhs, rhs); + return df.calc(rhs); } template <HnswIndexType type> double -HnswIndex<type>::calc_distance(const TypedCells& lhs, uint32_t rhs_docid, uint32_t rhs_subspace) const +HnswIndex<type>::calc_distance(const BoundDistanceFunction &df, uint32_t rhs_docid, uint32_t rhs_subspace) const { auto rhs = get_vector(rhs_docid, rhs_subspace); - return _distance_func->calc(lhs, rhs); + return df.calc(rhs); } template <HnswIndexType type> @@ -323,7 +325,9 @@ HnswIndex<type>::estimate_visited_nodes(uint32_t level, uint32_t nodeid_limit, u template <HnswIndexType type> HnswCandidate -HnswIndex<type>::find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level) const +HnswIndex<type>::find_nearest_in_layer( + const BoundDistanceFunction &df, + const HnswCandidate& entry_point, uint32_t level) const { HnswCandidate nearest = entry_point; bool keep_searching = true; @@ -334,7 +338,7 @@ HnswIndex<type>::find_nearest_in_layer(const TypedCells& input, const HnswCandid auto neighbor_ref = neighbor_node.levels_ref().load_acquire(); uint32_t neighbor_docid = acquire_docid(neighbor_node, neighbor_nodeid); uint32_t neighbor_subspace = neighbor_node.acquire_subspace(); - double dist = calc_distance(input, neighbor_docid, neighbor_subspace); + double dist = calc_distance(df, neighbor_docid, neighbor_subspace); if (_graph.still_valid(neighbor_nodeid, neighbor_ref) && dist < nearest.distance) { @@ -349,9 +353,11 @@ HnswIndex<type>::find_nearest_in_layer(const TypedCells& input, const HnswCandid template <HnswIndexType type> template <class VisitedTracker, class BestNeighbors> void -HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors_to_find, - BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter, - uint32_t nodeid_limit, uint32_t estimated_visited_nodes) const +HnswIndex<type>::search_layer_helper( + const BoundDistanceFunction &df, + uint32_t neighbors_to_find, + BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter, + uint32_t nodeid_limit, uint32_t estimated_visited_nodes) const { NearestPriQ candidates; GlobalFilterWrapper<type> filter_wrapper(filter); @@ -389,7 +395,7 @@ HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors } uint32_t neighbor_docid = acquire_docid(neighbor_node, neighbor_nodeid); uint32_t neighbor_subspace = neighbor_node.acquire_subspace(); - double dist_to_input = calc_distance(input, neighbor_docid, neighbor_subspace); + double dist_to_input = calc_distance(df, neighbor_docid, neighbor_subspace); if (dist_to_input < limit_dist) { candidates.emplace(neighbor_nodeid, neighbor_ref, dist_to_input); if (filter_wrapper.check(neighbor_docid)) { @@ -407,29 +413,31 @@ HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors template <HnswIndexType type> template <class BestNeighbors> void -HnswIndex<type>::search_layer(const TypedCells& input, uint32_t neighbors_to_find, - BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter) const +HnswIndex<type>::search_layer( + const BoundDistanceFunction &df, + uint32_t neighbors_to_find, + BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter) const { uint32_t nodeid_limit = _graph.nodes_size.load(std::memory_order_acquire); uint32_t estimated_visited_nodes = estimate_visited_nodes(level, nodeid_limit, neighbors_to_find, filter); if (estimated_visited_nodes >= nodeid_limit / 128) { - search_layer_helper<BitVectorVisitedTracker>(input, neighbors_to_find, best_neighbors, level, filter, nodeid_limit, estimated_visited_nodes); + search_layer_helper<BitVectorVisitedTracker>(df, neighbors_to_find, best_neighbors, level, filter, nodeid_limit, estimated_visited_nodes); } else { - search_layer_helper<HashSetVisitedTracker>(input, neighbors_to_find, best_neighbors, level, filter, nodeid_limit, estimated_visited_nodes); + search_layer_helper<HashSetVisitedTracker>(df, neighbors_to_find, best_neighbors, level, filter, nodeid_limit, estimated_visited_nodes); } } template <HnswIndexType type> -HnswIndex<type>::HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func, +HnswIndex<type>::HnswIndex(const DocVectorAccess& vectors, DistanceFunctionFactory::UP distance_ff, RandomLevelGenerator::UP level_generator, const HnswIndexConfig& cfg) : _graph(), _vectors(vectors), - _distance_func(std::move(distance_func)), + _distance_ff(std::move(distance_ff)), _level_generator(std::move(level_generator)), _id_mapping(), _cfg(cfg) { - assert(_distance_func); + assert(_distance_ff); } template <HnswIndexType type> @@ -483,12 +491,13 @@ HnswIndex<type>::internal_prepare_add_node(PreparedAddDoc& op, TypedCells input_ return; } int search_level = entry.level; - double entry_dist = calc_distance(input_vector, entry.nodeid); + auto df = _distance_ff->for_insertion_vector(input_vector); + double entry_dist = calc_distance(*df, entry.nodeid); uint32_t entry_docid = get_docid(entry.nodeid); // TODO: check if entry nodeid/levels_ref is still valid here HnswCandidate entry_point(entry.nodeid, entry_docid, entry.levels_ref, entry_dist); while (search_level > node_max_level) { - entry_point = find_nearest_in_layer(input_vector, entry_point, search_level); + entry_point = find_nearest_in_layer(*df, entry_point, search_level); --search_level; } @@ -497,7 +506,7 @@ HnswIndex<type>::internal_prepare_add_node(PreparedAddDoc& op, TypedCells input_ search_level = std::min(node_max_level, search_level); // Find neighbors of the added document in each level it should exist in. while (search_level >= 0) { - search_layer(input_vector, _cfg.neighbors_to_explore_at_construction(), best_neighbors, search_level); + search_layer(*df, _cfg.neighbors_to_explore_at_construction(), best_neighbors, search_level); auto neighbors = select_neighbors(best_neighbors.peek(), _cfg.max_links_on_inserts()); auto& links = connections[search_level]; links.reserve(neighbors.used.size()); @@ -850,11 +859,13 @@ struct NeighborsByDocId { template <HnswIndexType type> std::vector<NearestNeighborIndex::Neighbor> -HnswIndex<type>::top_k_by_docid(uint32_t k, TypedCells vector, - const GlobalFilter *filter, uint32_t explore_k, - double distance_threshold) const +HnswIndex<type>::top_k_by_docid( + uint32_t k, + const BoundDistanceFunction &df, + const GlobalFilter *filter, uint32_t explore_k, + double distance_threshold) const { - SearchBestNeighbors candidates = top_k_candidates(vector, std::max(k, explore_k), filter); + SearchBestNeighbors candidates = top_k_candidates(df, std::max(k, explore_k), filter); auto result = candidates.get_neighbors(k, distance_threshold); std::sort(result.begin(), result.end(), NeighborsByDocId()); return result; @@ -862,24 +873,31 @@ HnswIndex<type>::top_k_by_docid(uint32_t k, TypedCells vector, template <HnswIndexType type> std::vector<NearestNeighborIndex::Neighbor> -HnswIndex<type>::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k, - double distance_threshold) const +HnswIndex<type>::find_top_k( + uint32_t k, + const BoundDistanceFunction &df, + uint32_t explore_k, + double distance_threshold) const { - return top_k_by_docid(k, vector, nullptr, explore_k, distance_threshold); + return top_k_by_docid(k, df, nullptr, explore_k, distance_threshold); } template <HnswIndexType type> std::vector<NearestNeighborIndex::Neighbor> -HnswIndex<type>::find_top_k_with_filter(uint32_t k, TypedCells vector, - const GlobalFilter &filter, uint32_t explore_k, - double distance_threshold) const +HnswIndex<type>::find_top_k_with_filter( + uint32_t k, + const BoundDistanceFunction &df, + const GlobalFilter &filter, uint32_t explore_k, + double distance_threshold) const { - return top_k_by_docid(k, vector, &filter, explore_k, distance_threshold); + return top_k_by_docid(k, df, &filter, explore_k, distance_threshold); } template <HnswIndexType type> typename HnswIndex<type>::SearchBestNeighbors -HnswIndex<type>::top_k_candidates(const TypedCells &vector, uint32_t k, const GlobalFilter *filter) const +HnswIndex<type>::top_k_candidates( + const BoundDistanceFunction &df, + uint32_t k, const GlobalFilter *filter) const { SearchBestNeighbors best_neighbors; auto entry = _graph.get_entry_node(); @@ -888,16 +906,16 @@ HnswIndex<type>::top_k_candidates(const TypedCells &vector, uint32_t k, const Gl return best_neighbors; } int search_level = entry.level; - double entry_dist = calc_distance(vector, entry.nodeid); + double entry_dist = calc_distance(df, entry.nodeid); uint32_t entry_docid = get_docid(entry.nodeid); // TODO: check if entry docid/levels_ref is still valid here HnswCandidate entry_point(entry.nodeid, entry_docid, entry.levels_ref, entry_dist); while (search_level > 0) { - entry_point = find_nearest_in_layer(vector, entry_point, search_level); + entry_point = find_nearest_in_layer(df, entry_point, search_level); --search_level; } best_neighbors.push(entry_point); - search_layer(vector, k, best_neighbors, 0, filter); + search_layer(df, k, best_neighbors, 0, filter); return best_neighbors; } diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index 984acc6c9a1..0809dcf4fe3 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -4,6 +4,7 @@ #include "hnsw_index_config.h" #include "distance_function.h" +#include "distance_function_factory.h" #include "doc_vector_access.h" #include "hnsw_identity_mapping.h" #include "hnsw_index_utils.h" @@ -104,7 +105,7 @@ protected: GraphType _graph; const DocVectorAccess& _vectors; - DistanceFunction::UP _distance_func; + std::unique_ptr<DistanceFunctionFactory> _distance_ff; RandomLevelGenerator::UP _level_generator; IdMapping _id_mapping; // mapping from docid to nodeid vector HnswIndexConfig _cfg; @@ -158,23 +159,23 @@ protected: } double calc_distance(uint32_t lhs_nodeid, uint32_t rhs_nodeid) const; - double calc_distance(const TypedCells& lhs, uint32_t rhs_nodeid) const; - double calc_distance(const TypedCells& lhs, uint32_t rhs_docid, uint32_t rhs_subspace) const; + double calc_distance(const BoundDistanceFunction &df, uint32_t rhs_nodeid) const; + double calc_distance(const BoundDistanceFunction &df, uint32_t rhs_docid, uint32_t rhs_subspace) const; uint32_t estimate_visited_nodes(uint32_t level, uint32_t nodeid_limit, uint32_t neighbors_to_find, const GlobalFilter* filter) const; /** * Performs a greedy search in the given layer to find the candidate that is nearest the input vector. */ - HnswCandidate find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level) const; + HnswCandidate find_nearest_in_layer(const BoundDistanceFunction &df, const HnswCandidate& entry_point, uint32_t level) const; template <class VisitedTracker, class BestNeighbors> - void search_layer_helper(const TypedCells& input, uint32_t neighbors_to_find, BestNeighbors& best_neighbors, + void search_layer_helper(const BoundDistanceFunction &df, uint32_t neighbors_to_find, BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter, uint32_t nodeid_limit, uint32_t estimated_visited_nodes) const; template <class BestNeighbors> - void search_layer(const TypedCells& input, uint32_t neighbors_to_find, BestNeighbors& best_neighbors, + void search_layer(const BoundDistanceFunction &df, uint32_t neighbors_to_find, BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter = nullptr) const; - std::vector<Neighbor> top_k_by_docid(uint32_t k, TypedCells vector, + std::vector<Neighbor> top_k_by_docid(uint32_t k, const BoundDistanceFunction &df, const GlobalFilter *filter, uint32_t explore_k, double distance_threshold) const; @@ -185,7 +186,7 @@ protected: void internal_complete_add(uint32_t docid, internal::PreparedAddDoc &op); void internal_complete_add_node(uint32_t nodeid, uint32_t docid, uint32_t subspace, internal::PreparedAddNode &prepared_node); public: - HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func, + HnswIndex(const DocVectorAccess& vectors, DistanceFunctionFactory::UP distance_ff, RandomLevelGenerator::UP level_generator, const HnswIndexConfig& cfg); ~HnswIndex() override; @@ -213,14 +214,23 @@ public: std::unique_ptr<NearestNeighborIndexSaver> make_saver() const override; std::unique_ptr<NearestNeighborIndexLoader> make_loader(FastOS_FileInterface& file) override; - std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k, - double distance_threshold) const override; - std::vector<Neighbor> find_top_k_with_filter(uint32_t k, TypedCells vector, - const GlobalFilter &filter, uint32_t explore_k, - double distance_threshold) const override; - const DistanceFunction *distance_function() const override { return _distance_func.get(); } + std::vector<Neighbor> find_top_k( + uint32_t k, + const BoundDistanceFunction &df, + uint32_t explore_k, + double distance_threshold) const override; - SearchBestNeighbors top_k_candidates(const TypedCells &vector, uint32_t k, const GlobalFilter *filter) const; + std::vector<Neighbor> find_top_k_with_filter( + uint32_t k, + const BoundDistanceFunction &df, + const GlobalFilter &filter, uint32_t explore_k, + double distance_threshold) const override; + + DistanceFunctionFactory &distance_function_factory() const override { return *_distance_ff; } + + SearchBestNeighbors top_k_candidates( + const BoundDistanceFunction &df, + uint32_t k, const GlobalFilter *filter) const; uint32_t get_entry_nodeid() const { return _graph.get_entry_node().nodeid; } int32_t get_entry_level() const { return _graph.get_entry_node().level; } diff --git a/searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h index ec6774c9517..b734663a6f4 100644 --- a/searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h @@ -12,6 +12,7 @@ namespace vespalib::slime { struct Inserter; } namespace search::tensor { +struct DistanceFunctionFactory; class NearestNeighborIndex; class SerializedTensorRef; @@ -32,6 +33,7 @@ public: virtual const vespalib::eval::ValueType & getTensorType() const = 0; + virtual DistanceFunctionFactory& distance_function_factory() const = 0; virtual const NearestNeighborIndex* nearest_neighbor_index() const { return nullptr; } using DistanceMetric = search::attribute::DistanceMetric; virtual DistanceMetric distance_metric() const = 0; diff --git a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h index 4e1cc9efd96..0fb0fd1bf78 100644 --- a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h +++ b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h @@ -38,6 +38,9 @@ public: SerializedTensorRef get_serialized_tensor_ref(uint32_t docid) const override; bool supports_extract_cells_ref() const override { return _target_tensor_attribute.supports_extract_cells_ref(); } bool supports_get_tensor_ref() const override { return _target_tensor_attribute.supports_get_tensor_ref(); } + DistanceFunctionFactory& distance_function_factory() const override { + return _target_tensor_attribute.distance_function_factory(); + } DistanceMetric distance_metric() const override { return _target_tensor_attribute.distance_metric(); } bool supports_get_serialized_tensor_ref() const override; uint32_t get_num_docs() const override { return getNumDocs(); } diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h index 3f6c9b82a65..4b7b934fee0 100644 --- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h @@ -3,6 +3,7 @@ #pragma once #include "distance_function.h" +#include "distance_function_factory.h" #include "prepare_result.h" #include "vector_bundle.h" #include <vespa/vespalib/util/generationhandler.h> @@ -97,18 +98,18 @@ public: virtual std::unique_ptr<NearestNeighborIndexLoader> make_loader(FastOS_FileInterface& file) = 0; virtual std::vector<Neighbor> find_top_k(uint32_t k, - vespalib::eval::TypedCells vector, + const BoundDistanceFunction &df, uint32_t explore_k, double distance_threshold) const = 0; // only return neighbors where the corresponding filter bit is set virtual std::vector<Neighbor> find_top_k_with_filter(uint32_t k, - vespalib::eval::TypedCells vector, + const BoundDistanceFunction &df, const GlobalFilter &filter, uint32_t explore_k, double distance_threshold) const = 0; - virtual const DistanceFunction *distance_function() const = 0; + virtual DistanceFunctionFactory &distance_function_factory() const = 0; }; } diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp index 1e388199ef8..5e554f76779 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp @@ -58,6 +58,7 @@ TensorAttribute::TensorAttribute(vespalib::stringref name, const Config &cfg, Te : NotImplementedAttribute(name, cfg), _refVector(cfg.getGrowStrategy(), getGenerationHolder()), _tensorStore(tensorStore), + _distance_function_factory(make_distance_function_factory(cfg.distance_metric(), cfg.tensorType().cell_type())), _index(), _is_dense(cfg.tensorType().is_dense()), _emptyTensor(createEmptyTensor(cfg.tensorType())), @@ -280,6 +281,13 @@ TensorAttribute::getTensorType() const return getConfig().tensorType(); } +DistanceFunctionFactory& +TensorAttribute::distance_function_factory() const +{ + return *_distance_function_factory; + +} + const NearestNeighborIndex* TensorAttribute::nearest_neighbor_index() const { diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h index 4cb903c6c67..f629562a34d 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h @@ -28,6 +28,7 @@ protected: RefVector _refVector; // docId -> ref in data store for serialized tensor TensorStore &_tensorStore; // data store for serialized tensors + std::unique_ptr<DistanceFunctionFactory> _distance_function_factory; std::unique_ptr<NearestNeighborIndex> _index; bool _is_dense; std::unique_ptr<vespalib::eval::Value> _emptyTensor; @@ -67,6 +68,7 @@ public: bool supports_get_tensor_ref() const override { return false; } bool supports_get_serialized_tensor_ref() const override; const vespalib::eval::ValueType & getTensorType() const override; + DistanceFunctionFactory& distance_function_factory() const override; const NearestNeighborIndex* nearest_neighbor_index() const override; void get_state(const vespalib::slime::Inserter& inserter) const override; void clearDocs(DocId lidLow, DocId lidLimit, bool in_shrink_lid_space) override; diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp index 19c8cf6053b..f474d65a19d 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp @@ -36,6 +36,7 @@ TensorExtAttribute::TensorExtAttribute(const vespalib::string& name, const Confi : NotImplementedAttribute(name, cfg), ITensorAttribute(), IExtendAttribute(), + _distance_function_factory(make_distance_function_factory(cfg.distance_metric(), cfg.tensorType().cell_type())), _subspace_type(cfg.tensorType()), _empty(_subspace_type), _empty_tensor(create_empty_tensor(cfg.tensorType())) diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h index a58426cd146..93d7a94c257 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h @@ -5,6 +5,7 @@ #include "i_tensor_attribute.h" #include "empty_subspace.h" #include "subspace_type.h" +#include "distance_function_factory.h" #include <vespa/searchlib/attribute/not_implemented_attribute.h> #include <vespa/vespalib/stllike/allocator.h> @@ -20,9 +21,12 @@ class TensorExtAttribute : public NotImplementedAttribute, public IExtendAttribute { std::vector<const vespalib::eval::Value*> _data; + // XXX this should probably be longer-lived: + std::unique_ptr<DistanceFunctionFactory> _distance_function_factory; SubspaceType _subspace_type; EmptySubspace _empty; std::unique_ptr<vespalib::eval::Value> _empty_tensor; + public: TensorExtAttribute(const vespalib::string& name, const Config& cfg); ~TensorExtAttribute() override; @@ -46,6 +50,9 @@ public: bool supports_get_tensor_ref() const override; bool supports_get_serialized_tensor_ref() const override; const vespalib::eval::ValueType & getTensorType() const override; + DistanceFunctionFactory& distance_function_factory() const override { + return *_distance_function_factory; + } search::attribute::DistanceMetric distance_metric() const override; uint32_t get_num_docs() const override; void get_state(const vespalib::slime::Inserter& inserter) const override; |