summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp')
-rw-r--r--searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp18
1 files changed, 11 insertions, 7 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 5f4db88bf4c..1e341eab707 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -10,6 +10,7 @@
#include <vespa/searchlib/queryeval/nns_index_iterator.h>
#include <vespa/searchlib/queryeval/simpleresult.h>
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
+#include <vespa/searchlib/tensor/distance_calculator.h>
#include <vespa/searchlib/tensor/distance_function_factory.h>
#include <vespa/vespalib/gtest/gtest.h>
#include <vespa/vespalib/test/insertion_operators.h>
@@ -25,6 +26,7 @@ using search::BitVector;
using search::attribute::DistanceMetric;
using search::feature_t;
using search::tensor::DenseTensorAttribute;
+using search::tensor::DistanceCalculator;
using search::tensor::DistanceFunction;
using vespalib::eval::CellType;
using vespalib::eval::SimpleValue;
@@ -111,11 +113,11 @@ struct Fixture
setTensor(docId, *t);
}
- const DistanceFunction *dist_fun() const {
+ const DistanceFunction &dist_fun() const {
if (_cfg.tensorType().cell_type() == CellType::FLOAT) {
- return euclid_f.get();
+ return *euclid_f;
} else {
- return euclid_d.get();
+ return *euclid_d;
}
}
};
@@ -125,10 +127,11 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._tensorAttr);
+ DistanceCalculator dist_calc(attr, qtv, env.dist_fun());
NearestNeighborDistanceHeap dh(2);
- dh.set_distance_threshold(env.dist_fun()->convert_threshold(threshold));
+ dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold));
const BitVector *filter = env._global_filter.get();
- auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, filter, env.dist_fun());
+ auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, filter);
if (strict) {
return SimpleResult().searchStrict(*search, attr.getNumDocs());
} else {
@@ -217,8 +220,9 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) {
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._tensorAttr);
+ DistanceCalculator dist_calc(attr, qtv, env.dist_fun());
NearestNeighborDistanceHeap dh(2);
- auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, nullptr, env.dist_fun());
+ auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, nullptr);
uint32_t limit = attr.getNumDocs();
uint32_t docid = 1;
search->initRange(docid, limit);
@@ -268,7 +272,7 @@ TEST(NnsIndexIteratorTest, require_that_iterator_works_as_expected) {
std::vector<NnsIndexIterator::Hit> hits{{2,4.0}, {3,9.0}, {5,1.0}, {8,16.0}, {9,36.0}};
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
- auto search = NnsIndexIterator::create(tfmd, hits, euclid_d.get());
+ auto search = NnsIndexIterator::create(tfmd, hits, *euclid_d);
uint32_t docid = 1;
search->initFullRange();
bool match = search->seek(docid);