diff options
Diffstat (limited to 'searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp')
-rw-r--r-- | searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 5f4db88bf4c..1e341eab707 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -10,6 +10,7 @@ #include <vespa/searchlib/queryeval/nns_index_iterator.h> #include <vespa/searchlib/queryeval/simpleresult.h> #include <vespa/searchlib/tensor/dense_tensor_attribute.h> +#include <vespa/searchlib/tensor/distance_calculator.h> #include <vespa/searchlib/tensor/distance_function_factory.h> #include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/test/insertion_operators.h> @@ -25,6 +26,7 @@ using search::BitVector; using search::attribute::DistanceMetric; using search::feature_t; using search::tensor::DenseTensorAttribute; +using search::tensor::DistanceCalculator; using search::tensor::DistanceFunction; using vespalib::eval::CellType; using vespalib::eval::SimpleValue; @@ -111,11 +113,11 @@ struct Fixture setTensor(docId, *t); } - const DistanceFunction *dist_fun() const { + const DistanceFunction &dist_fun() const { if (_cfg.tensorType().cell_type() == CellType::FLOAT) { - return euclid_f.get(); + return *euclid_f; } else { - return euclid_d.get(); + return *euclid_d; } } }; @@ -125,10 +127,11 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._tensorAttr); + DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); NearestNeighborDistanceHeap dh(2); - dh.set_distance_threshold(env.dist_fun()->convert_threshold(threshold)); + dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold)); const BitVector *filter = env._global_filter.get(); - auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, filter, env.dist_fun()); + auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, filter); if (strict) { return SimpleResult().searchStrict(*search, attr.getNumDocs()); } else { @@ -217,8 +220,9 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) { auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._tensorAttr); + DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); NearestNeighborDistanceHeap dh(2); - auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, nullptr, env.dist_fun()); + auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, nullptr); uint32_t limit = attr.getNumDocs(); uint32_t docid = 1; search->initRange(docid, limit); @@ -268,7 +272,7 @@ TEST(NnsIndexIteratorTest, require_that_iterator_works_as_expected) { std::vector<NnsIndexIterator::Hit> hits{{2,4.0}, {3,9.0}, {5,1.0}, {8,16.0}, {9,36.0}}; auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); - auto search = NnsIndexIterator::create(tfmd, hits, euclid_d.get()); + auto search = NnsIndexIterator::create(tfmd, hits, *euclid_d); uint32_t docid = 1; search->initFullRange(); bool match = search->seek(docid); |