summary | refs | log | tree | commit | diff | stats
path: root/searchlib
diff options
context:
space:
mode:
author: Arne Juul <arnej@yahooinc.com> 2023-04-19 20:46:33 +0000
committer: Arne Juul <arnej@yahooinc.com> 2023-04-19 20:46:38 +0000
commit: 732e4c4be8bbc5a43e3adae5db222301e630bd8c (patch)
tree: 37926eda6dfa9bcbc87f4c96f74bf487ddc53ffe /searchlib
parent: 3880d66a21f151e97ac6fb892aa56909591e830e (diff)
follow-up after review
* add class comment on API declaration
* prefer snake_case for methods
* prefer reference
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp | 6
-rw-r--r-- searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 6
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp | 2
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp | 20
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/distance_function_factory.h | 9
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp | 4
6 files changed, 26 insertions, 21 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index c687a46186a..fd07529795a 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -136,7 +136,7 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std
auto &attr = *(env._attr);
auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, qtv.cells().type);
- DistanceCalculator dist_calc(attr, dff->forQueryVector(qtv.cells()));
+ DistanceCalculator dist_calc(attr, dff->for_query_vector(qtv.cells()));
NearestNeighborDistanceHeap dh(2);
dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold));
const GlobalFilter &filter = *env._global_filter;
@@ -263,7 +263,7 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) {
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._attr);
auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, qtv.cells().type);
- DistanceCalculator dist_calc(attr, dff->forQueryVector(qtv.cells()));
+ DistanceCalculator dist_calc(attr, dff->for_query_vector(qtv.cells()));
NearestNeighborDistanceHeap dh(2);
auto dummy_filter = GlobalFilter::create();
auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, *dummy_filter);
@@ -338,7 +338,7 @@ TEST(NnsIndexIteratorTest, require_that_iterator_works_as_expected) {
auto &tfmd = *(md->resolveTermField(0));
auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, CellType::DOUBLE);
vespalib::eval::TypedCells dummy;
- auto df = dff->forQueryVector(dummy);
+ auto df = dff->for_query_vector(dummy);
auto search = NnsIndexIterator::create(tfmd, hits, *df);
search->initFullRange();
expect_not_match(*search, 1, 2);
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 768157412f9..9f6216f5867 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -171,7 +171,7 @@ public:
uint32_t explore_k = 100;
vespalib::ArrayRef qv_ref(qv);
vespalib::eval::TypedCells qv_cells(qv_ref);
- auto df = index->distance_function_factory().forQueryVector(qv_cells);
+ auto df = index->distance_function_factory().for_query_vector(qv_cells);
auto got_by_docid = (global_filter->is_active()) ?
index->find_top_k_with_filter(k, *df, *global_filter, explore_k, 10000.0) :
index->find_top_k(k, *df, explore_k, 10000.0);
@@ -185,7 +185,7 @@ public:
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
uint32_t k = 3;
auto qv = vectors.get_vector(docid, 0);
- auto df = index->distance_function_factory().forQueryVector(qv);
+ auto df = index->distance_function_factory().for_query_vector(qv);
auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
size_t idx = 0;
@@ -206,7 +206,7 @@ public:
}
void check_with_distance_threshold(uint32_t docid) {
auto qv = vectors.get_vector(docid, 0);
- auto df = index->distance_function_factory().forQueryVector(qv);
+ auto df = index->distance_function_factory().for_query_vector(qv);
uint32_t k = 3;
auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
index 9a2287af074..8da777d97eb 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
@@ -67,7 +67,7 @@ DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tens
_query_tensor_uptr = converter(query_ct, required_ct, *_query_tensor);
_query_tensor = _query_tensor_uptr.get();
}
- _dist_fun = dff.forQueryVector(_query_tensor->cells());
+ _dist_fun = dff.for_query_vector(_query_tensor->cells());
assert(_dist_fun);
}
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
index 0433e4824aa..f96715bcf60 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
@@ -50,26 +50,26 @@ make_distance_function(DistanceMetric variant, CellType cell_type)
class SimpleBoundDistanceFunction : public BoundDistanceFunction {
const vespalib::eval::TypedCells _lhs;
- const DistanceFunction *_df;
+ const DistanceFunction &_df;
public:
SimpleBoundDistanceFunction(const vespalib::eval::TypedCells& lhs,
- const DistanceFunction *df)
+ const DistanceFunction &df)
: BoundDistanceFunction(lhs.type),
_lhs(lhs),
_df(df)
{}
double calc(const vespalib::eval::TypedCells& rhs) const override {
- return _df->calc(_lhs, rhs);
+ return _df.calc(_lhs, rhs);
}
double convert_threshold(double threshold) const override {
- return _df->convert_threshold(threshold);
+ return _df.convert_threshold(threshold);
}
double to_rawscore(double distance) const override {
- return _df->to_rawscore(distance);
+ return _df.to_rawscore(distance);
}
double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override {
- return _df->calc_with_limit(_lhs, rhs, limit);
+ return _df.calc_with_limit(_lhs, rhs, limit);
}
};
@@ -81,11 +81,11 @@ public:
_df(std::move(df))
{}
- BoundDistanceFunction::UP forQueryVector(const vespalib::eval::TypedCells& lhs) override {
- return std::make_unique<SimpleBoundDistanceFunction>(lhs, _df.get());
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override {
+ return std::make_unique<SimpleBoundDistanceFunction>(lhs, *_df);
}
- BoundDistanceFunction::UP forInsertionVector(const vespalib::eval::TypedCells& lhs) override {
- return std::make_unique<SimpleBoundDistanceFunction>(lhs, _df.get());
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override {
+ return std::make_unique<SimpleBoundDistanceFunction>(lhs, *_df);
}
};
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h
index 76ed0e59358..1edb94bd7aa 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.h
@@ -9,12 +9,17 @@
namespace search::tensor {
+/**
+ * API for binding the LHS of a distance calculation
+ * This allows keeping global state in the factory itself, and state
+ * for one particular vector in the distance function object.
+ */
struct DistanceFunctionFactory {
const vespalib::eval::CellType expected_cell_type;
DistanceFunctionFactory(vespalib::eval::CellType ct) : expected_cell_type(ct) {}
virtual ~DistanceFunctionFactory() {}
- virtual BoundDistanceFunction::UP forQueryVector(const vespalib::eval::TypedCells& lhs) = 0;
- virtual BoundDistanceFunction::UP forInsertionVector(const vespalib::eval::TypedCells& lhs) = 0;
+ virtual BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) = 0;
+ virtual BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) = 0;
using UP = std::unique_ptr<DistanceFunctionFactory>;
};
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index c879aa13571..fa7f150fd89 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -278,7 +278,7 @@ double
HnswIndex<type>::calc_distance(uint32_t lhs_nodeid, uint32_t rhs_nodeid) const
{
auto lhs = get_vector(lhs_nodeid);
- auto df = _distance_ff->forInsertionVector(lhs);
+ auto df = _distance_ff->for_insertion_vector(lhs);
auto rhs = get_vector(rhs_nodeid);
return df->calc(rhs);
}
@@ -491,7 +491,7 @@ HnswIndex<type>::internal_prepare_add_node(PreparedAddDoc& op, TypedCells input_
return;
}
int search_level = entry.level;
- auto df = _distance_ff->forInsertionVector(input_vector);
+ auto df = _distance_ff->for_insertion_vector(input_vector);
double entry_dist = calc_distance(*df, entry.nodeid);
uint32_t entry_docid = get_docid(entry.nodeid);
// TODO: check if entry nodeid/levels_ref is still valid here