summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2022-06-29 16:37:02 +0200
committerGitHub <noreply@github.com>2022-06-29 16:37:02 +0200
commit2edc796a4aae38fb6d468a69600da9ba07254fa5 (patch)
tree8bb2fd5411d2ca38fbdb43772c60ead6626b1750
parent5e2925ee24acfddbe14a8797917d18bcd7d9be26 (diff)
parent503e7815b43ada394d58ac6f6697ab767650683e (diff)
Merge pull request #23277 from vespa-engine/geirst/refactor-out-distance-calculator-class
Refactor out class to calculate the distance between attribute tensor…
-rw-r--r--searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp18
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp67
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h8
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp21
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h27
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp8
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h2
-rw-r--r--searchlib/src/vespa/searchlib/tensor/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp98
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_calculator.h49
10 files changed, 196 insertions, 103 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 5f4db88bf4c..1e341eab707 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -10,6 +10,7 @@
#include <vespa/searchlib/queryeval/nns_index_iterator.h>
#include <vespa/searchlib/queryeval/simpleresult.h>
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
+#include <vespa/searchlib/tensor/distance_calculator.h>
#include <vespa/searchlib/tensor/distance_function_factory.h>
#include <vespa/vespalib/gtest/gtest.h>
#include <vespa/vespalib/test/insertion_operators.h>
@@ -25,6 +26,7 @@ using search::BitVector;
using search::attribute::DistanceMetric;
using search::feature_t;
using search::tensor::DenseTensorAttribute;
+using search::tensor::DistanceCalculator;
using search::tensor::DistanceFunction;
using vespalib::eval::CellType;
using vespalib::eval::SimpleValue;
@@ -111,11 +113,11 @@ struct Fixture
setTensor(docId, *t);
}
- const DistanceFunction *dist_fun() const {
+ const DistanceFunction &dist_fun() const {
if (_cfg.tensorType().cell_type() == CellType::FLOAT) {
- return euclid_f.get();
+ return *euclid_f;
} else {
- return euclid_d.get();
+ return *euclid_d;
}
}
};
@@ -125,10 +127,11 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._tensorAttr);
+ DistanceCalculator dist_calc(attr, qtv, env.dist_fun());
NearestNeighborDistanceHeap dh(2);
- dh.set_distance_threshold(env.dist_fun()->convert_threshold(threshold));
+ dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold));
const BitVector *filter = env._global_filter.get();
- auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, filter, env.dist_fun());
+ auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, filter);
if (strict) {
return SimpleResult().searchStrict(*search, attr.getNumDocs());
} else {
@@ -217,8 +220,9 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) {
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._tensorAttr);
+ DistanceCalculator dist_calc(attr, qtv, env.dist_fun());
NearestNeighborDistanceHeap dh(2);
- auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, nullptr, env.dist_fun());
+ auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, nullptr);
uint32_t limit = attr.getNumDocs();
uint32_t docid = 1;
search->initRange(docid, limit);
@@ -268,7 +272,7 @@ TEST(NnsIndexIteratorTest, require_that_iterator_works_as_expected) {
std::vector<NnsIndexIterator::Hit> hits{{2,4.0}, {3,9.0}, {5,1.0}, {8,16.0}, {9,36.0}};
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
- auto search = NnsIndexIterator::create(tfmd, hits, euclid_d.get());
+ auto search = NnsIndexIterator::create(tfmd, hits, *euclid_d);
uint32_t docid = 1;
search->initFullRange();
bool match = search->seek(docid);
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index 8c03800b92a..8aa806b01cd 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -4,7 +4,6 @@
#include "nearest_neighbor_blueprint.h"
#include "nearest_neighbor_iterator.h"
#include "nns_index_iterator.h"
-#include <vespa/eval/eval/fast_value.h>
#include <vespa/searchlib/fef/termfieldmatchdataarray.h>
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
#include <vespa/searchlib/tensor/distance_function_factory.h>
@@ -13,45 +12,12 @@
LOG_SETUP(".searchlib.queryeval.nearest_neighbor_blueprint");
-using vespalib::eval::CellType;
-using vespalib::eval::FastValueBuilderFactory;
-using vespalib::eval::TypedCells;
using vespalib::eval::Value;
-using vespalib::eval::ValueType;
namespace search::queryeval {
namespace {
-template<typename LCT, typename RCT>
-std::unique_ptr<Value>
-convert_cells(const ValueType &new_type, std::unique_ptr<Value> old_value)
-{
- auto old_cells = old_value->cells().typify<LCT>();
- auto builder = FastValueBuilderFactory::get().create_value_builder<RCT>(new_type);
- auto new_cells = builder->add_subspace();
- assert(old_cells.size() == new_cells.size());
- auto p = new_cells.begin();
- for (LCT value : old_cells) {
- RCT conv(value);
- *p++ = conv;
- }
- return builder->build(std::move(builder));
-}
-
-struct ConvertCellsSelector
-{
- template <typename LCT, typename RCT>
- static auto invoke(const ValueType &new_type, std::unique_ptr<Value> old_value) {
- return convert_cells<LCT, RCT>(new_type, std::move(old_value));
- }
- auto operator() (CellType from, CellType to, std::unique_ptr<Value> old_value) const {
- using MyTypify = vespalib::eval::TypifyCellType;
- ValueType new_type = old_value->type().cell_cast(to);
- return vespalib::typify_invoke<2,MyTypify,ConvertCellsSelector>(from, to, new_type, std::move(old_value));
- }
-};
-
vespalib::string
to_string(NearestNeighborBlueprint::Algorithm algorithm)
{
@@ -78,7 +44,8 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
double global_filter_upper_limit)
: ComplexLeafBlueprint(field),
_attr_tensor(attr_tensor),
- _query_tensor(std::move(query_tensor)),
+ _distance_calc(_attr_tensor, std::move(query_tensor)),
+ _query_tensor(_distance_calc.query_tensor()),
_target_hits(target_hits),
_adjusted_target_hits(target_hits),
_approximate(approximate),
@@ -86,7 +53,6 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
_distance_threshold(std::numeric_limits<double>::max()),
_global_filter_lower_limit(global_filter_lower_limit),
_global_filter_upper_limit(global_filter_upper_limit),
- _fallback_dist_fun(),
_distance_heap(target_hits),
_found_hits(),
_algorithm(Algorithm::EXACT),
@@ -95,27 +61,13 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
_global_filter_hits(),
_global_filter_hit_ratio()
{
- CellType attr_ct = _attr_tensor.getTensorType().cell_type();
- _fallback_dist_fun = search::tensor::make_distance_function(_attr_tensor.distance_metric(), attr_ct);
- _dist_fun = _fallback_dist_fun.get();
- assert(_dist_fun);
- auto nns_index = _attr_tensor.nearest_neighbor_index();
- if (nns_index) {
- _dist_fun = nns_index->distance_function();
- assert(_dist_fun);
- }
- auto query_ct = _query_tensor->cells().type;
- CellType required_ct = _dist_fun->expected_cell_type();
- if (query_ct != required_ct) {
- ConvertCellsSelector converter;
- _query_tensor = converter(query_ct, required_ct, std::move(_query_tensor));
- }
if (distance_threshold < std::numeric_limits<double>::max()) {
- _distance_threshold = _dist_fun->convert_threshold(distance_threshold);
+ _distance_threshold = _distance_calc.function().convert_threshold(distance_threshold);
_distance_heap.set_distance_threshold(_distance_threshold);
}
uint32_t est_hits = _attr_tensor.get_num_docs();
setEstimate(HitEstimate(est_hits, false));
+ auto nns_index = _attr_tensor.nearest_neighbor_index();
set_want_global_filter(nns_index && _approximate);
}
@@ -155,7 +107,7 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter, d
void
NearestNeighborBlueprint::perform_top_k(const search::tensor::NearestNeighborIndex* nns_index)
{
- auto lhs = _query_tensor->cells();
+ auto lhs = _query_tensor.cells();
uint32_t k = _adjusted_target_hits;
if (_global_filter->has_filter()) {
auto filter = _global_filter->filter();
@@ -175,13 +127,12 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData
switch (_algorithm) {
case Algorithm::INDEX_TOP_K_WITH_FILTER:
case Algorithm::INDEX_TOP_K:
- return NnsIndexIterator::create(tfmd, _found_hits, _dist_fun);
+ return NnsIndexIterator::create(tfmd, _found_hits, _distance_calc.function());
default:
;
}
- const Value &qT = *_query_tensor;
- return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor,
- _distance_heap, _global_filter->filter(), _dist_fun);
+ return NearestNeighborIterator::create(strict, tfmd, _distance_calc,
+ _distance_heap, _global_filter->filter());
}
void
@@ -189,7 +140,7 @@ NearestNeighborBlueprint::visitMembers(vespalib::ObjectVisitor& visitor) const
{
ComplexLeafBlueprint::visitMembers(visitor);
visitor.visitString("attribute_tensor", _attr_tensor.getTensorType().to_spec());
- visitor.visitString("query_tensor", _query_tensor->type().to_spec());
+ visitor.visitString("query_tensor", _query_tensor.type().to_spec());
visitor.visitInt("target_hits", _target_hits);
visitor.visitInt("adjusted_target_hits", _adjusted_target_hits);
visitor.visitInt("explore_additional_hits", _explore_additional_hits);
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
index 16b0e13014e..3be7d7fd01f 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
@@ -3,6 +3,7 @@
#include "blueprint.h"
#include "nearest_neighbor_distance_heap.h"
+#include <vespa/searchlib/tensor/distance_calculator.h>
#include <vespa/searchlib/tensor/distance_function.h>
#include <vespa/searchlib/tensor/nearest_neighbor_index.h>
#include <optional>
@@ -28,7 +29,8 @@ public:
};
private:
const tensor::ITensorAttribute& _attr_tensor;
- std::unique_ptr<vespalib::eval::Value> _query_tensor;
+ search::tensor::DistanceCalculator _distance_calc;
+ const vespalib::eval::Value& _query_tensor;
uint32_t _target_hits;
uint32_t _adjusted_target_hits;
bool _approximate;
@@ -36,8 +38,6 @@ private:
double _distance_threshold;
double _global_filter_lower_limit;
double _global_filter_upper_limit;
- search::tensor::DistanceFunction::UP _fallback_dist_fun;
- const search::tensor::DistanceFunction *_dist_fun;
mutable NearestNeighborDistanceHeap _distance_heap;
std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits;
Algorithm _algorithm;
@@ -59,7 +59,7 @@ public:
NearestNeighborBlueprint& operator=(const NearestNeighborBlueprint&) = delete;
~NearestNeighborBlueprint();
const tensor::ITensorAttribute& get_attribute_tensor() const { return _attr_tensor; }
- const vespalib::eval::Value& get_query_tensor() const { return *_query_tensor; }
+ const vespalib::eval::Value& get_query_tensor() const { return _query_tensor; }
uint32_t get_target_hits() const { return _target_hits; }
uint32_t get_adjusted_target_hits() const { return _adjusted_target_hits; }
void set_global_filter(const GlobalFilter &global_filter, double estimated_hit_ratio) override;
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
index 6a00568bd06..e06fcc614d8 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
@@ -2,6 +2,8 @@
#include "nearest_neighbor_iterator.h"
#include <vespa/searchlib/common/bitvector.h>
+#include <vespa/searchlib/tensor/distance_calculator.h>
+#include <vespa/searchlib/tensor/distance_function.h>
using search::tensor::ITensorAttribute;
using vespalib::ConstArrayRef;
@@ -34,11 +36,10 @@ public:
NearestNeighborImpl(Params params_in)
: NearestNeighborIterator(params_in),
- _lhs(params().queryTensor.cells()),
_lastScore(0.0)
{
- assert(is_compatible(params().tensorAttribute.getTensorType(),
- params().queryTensor.type()));
+ assert(is_compatible(params().distance_calc.attribute_tensor().getTensorType(),
+ params().distance_calc.query_tensor().type()));
}
~NearestNeighborImpl();
@@ -64,7 +65,7 @@ public:
}
void doUnpack(uint32_t docId) override {
- double score = params().distanceFunction->to_rawscore(_lastScore);
+ double score = params().distance_calc.function().to_rawscore(_lastScore);
params().tfmd.setRawScore(docId, score);
params().distanceHeap.used(_lastScore);
}
@@ -73,11 +74,9 @@ public:
private:
double computeDistance(uint32_t docId, double limit) {
- auto rhs = params().tensorAttribute.extract_cells_ref(docId);
- return params().distanceFunction->calc_with_limit(_lhs, rhs, limit);
+ return params().distance_calc.calc_with_limit(docId, limit);
}
- TypedCells _lhs;
double _lastScore;
};
@@ -105,14 +104,12 @@ std::unique_ptr<NearestNeighborIterator>
NearestNeighborIterator::create(
bool strict,
fef::TermFieldMatchData &tfmd,
- const vespalib::eval::Value &queryTensor,
- const search::tensor::ITensorAttribute &tensorAttribute,
+ const search::tensor::DistanceCalculator &distance_calc,
NearestNeighborDistanceHeap &distanceHeap,
- const search::BitVector *filter,
- const search::tensor::DistanceFunction *dist_fun)
+ const search::BitVector *filter)
{
- Params params(tfmd, queryTensor, tensorAttribute, distanceHeap, filter, dist_fun);
+ Params params(tfmd, distance_calc, distanceHeap, filter);
if (filter) {
return resolve_strict<true>(strict, params);
} else {
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
index 66622288d84..0d8f70d15c2 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
@@ -7,10 +7,11 @@
#include <vespa/eval/eval/value.h>
#include <vespa/searchlib/fef/termfieldmatchdata.h>
#include <vespa/searchlib/tensor/i_tensor_attribute.h>
-#include <vespa/searchlib/tensor/distance_function.h>
#include <vespa/vespalib/util/priority_queue.h>
#include <cmath>
+namespace search::tensor { class DistanceCalculator; }
+
namespace search::queryeval {
class NearestNeighborIterator : public SearchIterator
@@ -21,24 +22,18 @@ public:
struct Params {
fef::TermFieldMatchData &tfmd;
- const Value &queryTensor;
- const ITensorAttribute &tensorAttribute;
+ const search::tensor::DistanceCalculator &distance_calc;
NearestNeighborDistanceHeap &distanceHeap;
const search::BitVector *filter;
- const search::tensor::DistanceFunction *distanceFunction;
-
+
Params(fef::TermFieldMatchData &tfmd_in,
- const Value &queryTensor_in,
- const ITensorAttribute &tensorAttribute_in,
+ const search::tensor::DistanceCalculator &distance_calc_in,
NearestNeighborDistanceHeap &distanceHeap_in,
- const search::BitVector *filter_in,
- const search::tensor::DistanceFunction *distanceFunction_in)
+ const search::BitVector *filter_in)
: tfmd(tfmd_in),
- queryTensor(queryTensor_in),
- tensorAttribute(tensorAttribute_in),
+ distance_calc(distance_calc_in),
distanceHeap(distanceHeap_in),
- filter(filter_in),
- distanceFunction(distanceFunction_in)
+ filter(filter_in)
{}
};
@@ -49,11 +44,9 @@ public:
static std::unique_ptr<NearestNeighborIterator> create(
bool strict,
fef::TermFieldMatchData &tfmd,
- const Value &queryTensor,
- const search::tensor::ITensorAttribute &tensorAttribute,
+ const search::tensor::DistanceCalculator &distance_calc,
NearestNeighborDistanceHeap &distanceHeap,
- const search::BitVector *filter,
- const search::tensor::DistanceFunction *dist_fun);
+ const search::BitVector *filter);
const Params& params() const { return _params; }
private:
diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp
index cd65f01025b..95264a79431 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp
@@ -18,13 +18,13 @@ class NeighborVectorIterator : public NnsIndexIterator
private:
fef::TermFieldMatchData &_tfmd;
const std::vector<Neighbor> &_hits;
- const search::tensor::DistanceFunction * const _dist_fun;
+ const search::tensor::DistanceFunction &_dist_fun;
uint32_t _idx;
double _last_abstract_dist;
public:
NeighborVectorIterator(fef::TermFieldMatchData &tfmd,
const std::vector<Neighbor> &hits,
- const search::tensor::DistanceFunction *dist_fun)
+ const search::tensor::DistanceFunction &dist_fun)
: _tfmd(tfmd),
_hits(hits),
_dist_fun(dist_fun),
@@ -54,7 +54,7 @@ public:
}
void doUnpack(uint32_t docId) override {
- double score = _dist_fun->to_rawscore(_last_abstract_dist);
+ double score = _dist_fun.to_rawscore(_last_abstract_dist);
_tfmd.setRawScore(docId, score);
}
@@ -65,7 +65,7 @@ std::unique_ptr<NnsIndexIterator>
NnsIndexIterator::create(
fef::TermFieldMatchData &tfmd,
const std::vector<Neighbor> &hits,
- const search::tensor::DistanceFunction *dist_fun)
+ const search::tensor::DistanceFunction &dist_fun)
{
return std::make_unique<NeighborVectorIterator>(tfmd, hits, dist_fun);
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h
index 019ac8579bd..031a603de49 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h
@@ -16,7 +16,7 @@ public:
static std::unique_ptr<NnsIndexIterator> create(
fef::TermFieldMatchData &tfmd,
const std::vector<Hit> &hits,
- const search::tensor::DistanceFunction *dist_fun);
+ const search::tensor::DistanceFunction &dist_fun);
};
} // namespace
diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
index ae34cdd66c8..9e0ccb8d37a 100644
--- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
@@ -11,6 +11,7 @@ vespa_add_library(searchlib_tensor OBJECT
direct_tensor_attribute.cpp
direct_tensor_saver.cpp
direct_tensor_store.cpp
+ distance_calculator.cpp
distance_function_factory.cpp
euclidean_distance.cpp
geo_degrees_distance.cpp
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
new file mode 100644
index 00000000000..6bb3d9ed49b
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
@@ -0,0 +1,98 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "distance_calculator.h"
+#include "distance_function_factory.h"
+#include "i_tensor_attribute.h"
+#include "nearest_neighbor_index.h"
+#include <vespa/eval/eval/fast_value.h>
+
+using vespalib::eval::CellType;
+using vespalib::eval::FastValueBuilderFactory;
+using vespalib::eval::TypedCells;
+using vespalib::eval::Value;
+using vespalib::eval::ValueType;
+
+namespace {
+
+template<typename LCT, typename RCT>
+std::unique_ptr<Value>
+convert_cells(const ValueType& new_type, std::unique_ptr<Value> old_value)
+{
+ auto old_cells = old_value->cells().typify<LCT>();
+ auto builder = FastValueBuilderFactory::get().create_value_builder<RCT>(new_type);
+ auto new_cells = builder->add_subspace();
+ assert(old_cells.size() == new_cells.size());
+ auto p = new_cells.begin();
+ for (LCT value : old_cells) {
+ RCT conv(value);
+ *p++ = conv;
+ }
+ return builder->build(std::move(builder));
+}
+
+struct ConvertCellsSelector
+{
+ template <typename LCT, typename RCT>
+ static auto invoke(const ValueType& new_type, std::unique_ptr<Value> old_value) {
+ return convert_cells<LCT, RCT>(new_type, std::move(old_value));
+ }
+ auto operator() (CellType from, CellType to, std::unique_ptr<Value> old_value) const {
+ using MyTypify = vespalib::eval::TypifyCellType;
+ ValueType new_type = old_value->type().cell_cast(to);
+ return vespalib::typify_invoke<2,MyTypify,ConvertCellsSelector>(from, to, new_type, std::move(old_value));
+ }
+};
+
+}
+
+namespace search::tensor {
+
+DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tensor,
+ std::unique_ptr<vespalib::eval::Value> query_tensor_in)
+ : _attr_tensor(attr_tensor),
+ _query_tensor_uptr(std::move(query_tensor_in)),
+ _query_tensor(),
+ _query_tensor_cells(),
+ _dist_fun_uptr(make_distance_function(_attr_tensor.distance_metric(),
+ _attr_tensor.getTensorType().cell_type())),
+ _dist_fun(_dist_fun_uptr.get())
+{
+ assert(_dist_fun);
+ auto nns_index = _attr_tensor.nearest_neighbor_index();
+ if (nns_index) {
+ _dist_fun = nns_index->distance_function();
+ assert(_dist_fun);
+ }
+ auto query_ct = _query_tensor_uptr->cells().type;
+ CellType required_ct = _dist_fun->expected_cell_type();
+ if (query_ct != required_ct) {
+ ConvertCellsSelector converter;
+ _query_tensor_uptr = converter(query_ct, required_ct, std::move(_query_tensor_uptr));
+ }
+ _query_tensor = _query_tensor_uptr.get();
+ _query_tensor_cells = _query_tensor->cells();
+}
+
+DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tensor,
+ const vespalib::eval::Value& query_tensor_in,
+ const DistanceFunction& function_in)
+ : _attr_tensor(attr_tensor),
+ _query_tensor_uptr(),
+ _query_tensor(&query_tensor_in),
+ _query_tensor_cells(_query_tensor->cells()),
+ _dist_fun_uptr(),
+ _dist_fun(&function_in)
+{
+}
+
+DistanceCalculator::~DistanceCalculator() = default;
+
+double
+DistanceCalculator::calc_with_limit(uint32_t docid, double limit) const
+{
+ auto rhs = _attr_tensor.extract_cells_ref(docid);
+ return _dist_fun->calc_with_limit(_query_tensor_cells, rhs, limit);
+}
+
+}
+
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
new file mode 100644
index 00000000000..df9344a24d1
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
@@ -0,0 +1,49 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#pragma once
+
+#include <vespa/eval/eval/typed_cells.h>
+#include <memory>
+
+namespace vespalib::eval { struct Value; }
+
+namespace search::tensor {
+
+class DistanceFunction;
+class ITensorAttribute;
+
+/**
+ * Class used to calculate the distance between two n-dimensional vectors,
+ * where one is stored in a TensorAttribute and the other comes from the query.
+ *
+ * The distance function to use is defined in the TensorAttribute.
+ */
+class DistanceCalculator {
+private:
+ const tensor::ITensorAttribute& _attr_tensor;
+ std::unique_ptr<vespalib::eval::Value> _query_tensor_uptr;
+ const vespalib::eval::Value* _query_tensor;
+ vespalib::eval::TypedCells _query_tensor_cells;
+ std::unique_ptr<DistanceFunction> _dist_fun_uptr;
+ const DistanceFunction* _dist_fun;
+
+public:
+ DistanceCalculator(const tensor::ITensorAttribute& attr_tensor,
+ std::unique_ptr<vespalib::eval::Value> query_tensor_in);
+
+ /**
+ * Only used by unit tests where ownership of query tensor and distance function is handled outside.
+ */
+ DistanceCalculator(const tensor::ITensorAttribute& attr_tensor,
+ const vespalib::eval::Value& query_tensor_in,
+ const DistanceFunction& function_in);
+
+ ~DistanceCalculator();
+
+ const tensor::ITensorAttribute& attribute_tensor() const { return _attr_tensor; }
+ const vespalib::eval::Value& query_tensor() const { return *_query_tensor; }
+ const DistanceFunction& function() const { return *_dist_fun; }
+
+ double calc_with_limit(uint32_t docid, double limit) const;
+};
+
+}