diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-04-19 20:46:33 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-04-19 20:46:38 +0000 |
commit | 732e4c4be8bbc5a43e3adae5db222301e630bd8c (patch) | |
tree | 37926eda6dfa9bcbc87f4c96f74bf487ddc53ffe /searchlib | |
parent | 3880d66a21f151e97ac6fb892aa56909591e830e (diff) |
follow-up after review
* add class comment on API declaration
* prefer snake_case for methods
* prefer reference over pointer
Diffstat (limited to 'searchlib')
6 files changed, 26 insertions, 21 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index c687a46186a..fd07529795a 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -136,7 +136,7 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std auto &attr = *(env._attr); auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, qtv.cells().type); - DistanceCalculator dist_calc(attr, dff->forQueryVector(qtv.cells())); + DistanceCalculator dist_calc(attr, dff->for_query_vector(qtv.cells())); NearestNeighborDistanceHeap dh(2); dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold)); const GlobalFilter &filter = *env._global_filter; @@ -263,7 +263,7 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) { auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._attr); auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, qtv.cells().type); - DistanceCalculator dist_calc(attr, dff->forQueryVector(qtv.cells())); + DistanceCalculator dist_calc(attr, dff->for_query_vector(qtv.cells())); NearestNeighborDistanceHeap dh(2); auto dummy_filter = GlobalFilter::create(); auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, *dummy_filter); @@ -338,7 +338,7 @@ TEST(NnsIndexIteratorTest, require_that_iterator_works_as_expected) { auto &tfmd = *(md->resolveTermField(0)); auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, CellType::DOUBLE); vespalib::eval::TypedCells dummy; - auto df = dff->forQueryVector(dummy); + auto df = dff->for_query_vector(dummy); auto search = NnsIndexIterator::create(tfmd, hits, *df); search->initFullRange(); expect_not_match(*search, 1, 2); diff --git 
a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 768157412f9..9f6216f5867 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -171,7 +171,7 @@ public: uint32_t explore_k = 100; vespalib::ArrayRef qv_ref(qv); vespalib::eval::TypedCells qv_cells(qv_ref); - auto df = index->distance_function_factory().forQueryVector(qv_cells); + auto df = index->distance_function_factory().for_query_vector(qv_cells); auto got_by_docid = (global_filter->is_active()) ? index->find_top_k_with_filter(k, *df, *global_filter, explore_k, 10000.0) : index->find_top_k(k, *df, explore_k, 10000.0); @@ -185,7 +185,7 @@ public: void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid, 0); - auto df = index->distance_function_factory().forQueryVector(qv); + auto df = index->distance_function_factory().for_query_vector(qv); auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; @@ -206,7 +206,7 @@ public: } void check_with_distance_threshold(uint32_t docid) { auto qv = vectors.get_vector(docid, 0); - auto df = index->distance_function_factory().forQueryVector(qv); + auto df = index->distance_function_factory().for_query_vector(qv); uint32_t k = 3; auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp index 9a2287af074..8da777d97eb 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp @@ -67,7 +67,7 @@ DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tens 
_query_tensor_uptr = converter(query_ct, required_ct, *_query_tensor); _query_tensor = _query_tensor_uptr.get(); } - _dist_fun = dff.forQueryVector(_query_tensor->cells()); + _dist_fun = dff.for_query_vector(_query_tensor->cells()); assert(_dist_fun); } diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp index 0433e4824aa..f96715bcf60 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp @@ -50,26 +50,26 @@ make_distance_function(DistanceMetric variant, CellType cell_type) class SimpleBoundDistanceFunction : public BoundDistanceFunction { const vespalib::eval::TypedCells _lhs; - const DistanceFunction *_df; + const DistanceFunction &_df; public: SimpleBoundDistanceFunction(const vespalib::eval::TypedCells& lhs, - const DistanceFunction *df) + const DistanceFunction &df) : BoundDistanceFunction(lhs.type), _lhs(lhs), _df(df) {} double calc(const vespalib::eval::TypedCells& rhs) const override { - return _df->calc(_lhs, rhs); + return _df.calc(_lhs, rhs); } double convert_threshold(double threshold) const override { - return _df->convert_threshold(threshold); + return _df.convert_threshold(threshold); } double to_rawscore(double distance) const override { - return _df->to_rawscore(distance); + return _df.to_rawscore(distance); } double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override { - return _df->calc_with_limit(_lhs, rhs, limit); + return _df.calc_with_limit(_lhs, rhs, limit); } }; @@ -81,11 +81,11 @@ public: _df(std::move(df)) {} - BoundDistanceFunction::UP forQueryVector(const vespalib::eval::TypedCells& lhs) override { - return std::make_unique<SimpleBoundDistanceFunction>(lhs, _df.get()); + BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override { + return 
std::make_unique<SimpleBoundDistanceFunction>(lhs, *_df); } - BoundDistanceFunction::UP forInsertionVector(const vespalib::eval::TypedCells& lhs) override { - return std::make_unique<SimpleBoundDistanceFunction>(lhs, _df.get()); + BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override { + return std::make_unique<SimpleBoundDistanceFunction>(lhs, *_df); } }; diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h index 76ed0e59358..1edb94bd7aa 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h @@ -9,12 +9,17 @@ namespace search::tensor { +/** + * API for binding the LHS of a distance calculation + * This allows keeping global state in the factory itself, and state + * for one particular vector in the distance function object. + */ struct DistanceFunctionFactory { const vespalib::eval::CellType expected_cell_type; DistanceFunctionFactory(vespalib::eval::CellType ct) : expected_cell_type(ct) {} virtual ~DistanceFunctionFactory() {} - virtual BoundDistanceFunction::UP forQueryVector(const vespalib::eval::TypedCells& lhs) = 0; - virtual BoundDistanceFunction::UP forInsertionVector(const vespalib::eval::TypedCells& lhs) = 0; + virtual BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) = 0; + virtual BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) = 0; using UP = std::unique_ptr<DistanceFunctionFactory>; }; diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index c879aa13571..fa7f150fd89 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -278,7 +278,7 @@ double HnswIndex<type>::calc_distance(uint32_t lhs_nodeid, uint32_t rhs_nodeid) const { 
auto lhs = get_vector(lhs_nodeid); - auto df = _distance_ff->forInsertionVector(lhs); + auto df = _distance_ff->for_insertion_vector(lhs); auto rhs = get_vector(rhs_nodeid); return df->calc(rhs); } @@ -491,7 +491,7 @@ HnswIndex<type>::internal_prepare_add_node(PreparedAddDoc& op, TypedCells input_ return; } int search_level = entry.level; - auto df = _distance_ff->forInsertionVector(input_vector); + auto df = _distance_ff->for_insertion_vector(input_vector); double entry_dist = calc_distance(*df, entry.nodeid); uint32_t entry_docid = get_docid(entry.nodeid); // TODO: check if entry nodeid/levels_ref is still valid here |