diff options
10 files changed, 196 insertions, 103 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 5f4db88bf4c..1e341eab707 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -10,6 +10,7 @@ #include <vespa/searchlib/queryeval/nns_index_iterator.h> #include <vespa/searchlib/queryeval/simpleresult.h> #include <vespa/searchlib/tensor/dense_tensor_attribute.h> +#include <vespa/searchlib/tensor/distance_calculator.h> #include <vespa/searchlib/tensor/distance_function_factory.h> #include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/test/insertion_operators.h> @@ -25,6 +26,7 @@ using search::BitVector; using search::attribute::DistanceMetric; using search::feature_t; using search::tensor::DenseTensorAttribute; +using search::tensor::DistanceCalculator; using search::tensor::DistanceFunction; using vespalib::eval::CellType; using vespalib::eval::SimpleValue; @@ -111,11 +113,11 @@ struct Fixture setTensor(docId, *t); } - const DistanceFunction *dist_fun() const { + const DistanceFunction &dist_fun() const { if (_cfg.tensorType().cell_type() == CellType::FLOAT) { - return euclid_f.get(); + return *euclid_f; } else { - return euclid_d.get(); + return *euclid_d; } } }; @@ -125,10 +127,11 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._tensorAttr); + DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); NearestNeighborDistanceHeap dh(2); - dh.set_distance_threshold(env.dist_fun()->convert_threshold(threshold)); + dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold)); const BitVector *filter = env._global_filter.get(); - auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, filter, env.dist_fun()); + auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, filter); if (strict) { return SimpleResult().searchStrict(*search, attr.getNumDocs()); } else { @@ -217,8 +220,9 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) { auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._tensorAttr); + DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); NearestNeighborDistanceHeap dh(2); - auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, nullptr, env.dist_fun()); + auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, nullptr); uint32_t limit = attr.getNumDocs(); uint32_t docid = 1; search->initRange(docid, limit); @@ -268,7 +272,7 @@ TEST(NnsIndexIteratorTest, require_that_iterator_works_as_expected) { std::vector<NnsIndexIterator::Hit> hits{{2,4.0}, {3,9.0}, {5,1.0}, {8,16.0}, {9,36.0}}; auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); - auto search = NnsIndexIterator::create(tfmd, hits, euclid_d.get()); + auto search = NnsIndexIterator::create(tfmd, hits, *euclid_d); uint32_t docid = 1; search->initFullRange(); bool match = search->seek(docid); diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index 8c03800b92a..8aa806b01cd 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -4,7 +4,6 @@ #include "nearest_neighbor_blueprint.h" #include "nearest_neighbor_iterator.h" #include "nns_index_iterator.h" -#include <vespa/eval/eval/fast_value.h> #include <vespa/searchlib/fef/termfieldmatchdataarray.h> #include <vespa/searchlib/tensor/dense_tensor_attribute.h> #include <vespa/searchlib/tensor/distance_function_factory.h> @@ -13,45 +12,12 @@ LOG_SETUP(".searchlib.queryeval.nearest_neighbor_blueprint"); -using vespalib::eval::CellType; -using vespalib::eval::FastValueBuilderFactory; -using vespalib::eval::TypedCells; using vespalib::eval::Value; -using vespalib::eval::ValueType; namespace search::queryeval { namespace { -template<typename LCT, typename RCT> -std::unique_ptr<Value> -convert_cells(const ValueType &new_type, std::unique_ptr<Value> old_value) -{ - auto old_cells = old_value->cells().typify<LCT>(); - auto builder = FastValueBuilderFactory::get().create_value_builder<RCT>(new_type); - auto new_cells = builder->add_subspace(); - assert(old_cells.size() == new_cells.size()); - auto p = new_cells.begin(); - for (LCT value : old_cells) { - RCT conv(value); - *p++ = conv; - } - return builder->build(std::move(builder)); -} - -struct ConvertCellsSelector -{ - template <typename LCT, typename RCT> - static auto invoke(const ValueType &new_type, std::unique_ptr<Value> old_value) { - return convert_cells<LCT, RCT>(new_type, std::move(old_value)); - } - auto operator() (CellType from, CellType to, std::unique_ptr<Value> old_value) const { - using MyTypify = vespalib::eval::TypifyCellType; - ValueType new_type = old_value->type().cell_cast(to); - return vespalib::typify_invoke<2,MyTypify,ConvertCellsSelector>(from, to, new_type, std::move(old_value)); - } -}; - vespalib::string to_string(NearestNeighborBlueprint::Algorithm algorithm) { @@ -78,7 +44,8 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f double global_filter_upper_limit) : ComplexLeafBlueprint(field), _attr_tensor(attr_tensor), - _query_tensor(std::move(query_tensor)), + _distance_calc(_attr_tensor, std::move(query_tensor)), + _query_tensor(_distance_calc.query_tensor()), _target_hits(target_hits), _adjusted_target_hits(target_hits), _approximate(approximate), @@ -86,7 +53,6 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f _distance_threshold(std::numeric_limits<double>::max()), _global_filter_lower_limit(global_filter_lower_limit), _global_filter_upper_limit(global_filter_upper_limit), - _fallback_dist_fun(), _distance_heap(target_hits), _found_hits(), _algorithm(Algorithm::EXACT), @@ -95,27 +61,13 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f _global_filter_hits(), _global_filter_hit_ratio() { - CellType attr_ct = _attr_tensor.getTensorType().cell_type(); - _fallback_dist_fun = search::tensor::make_distance_function(_attr_tensor.distance_metric(), attr_ct); - _dist_fun = _fallback_dist_fun.get(); - assert(_dist_fun); - auto nns_index = _attr_tensor.nearest_neighbor_index(); - if (nns_index) { - _dist_fun = nns_index->distance_function(); - assert(_dist_fun); - } - auto query_ct = _query_tensor->cells().type; - CellType required_ct = _dist_fun->expected_cell_type(); - if (query_ct != required_ct) { - ConvertCellsSelector converter; - _query_tensor = converter(query_ct, required_ct, std::move(_query_tensor)); - } if (distance_threshold < std::numeric_limits<double>::max()) { - _distance_threshold = _dist_fun->convert_threshold(distance_threshold); + _distance_threshold = _distance_calc.function().convert_threshold(distance_threshold); _distance_heap.set_distance_threshold(_distance_threshold); } uint32_t est_hits = _attr_tensor.get_num_docs(); setEstimate(HitEstimate(est_hits, false)); + auto nns_index = _attr_tensor.nearest_neighbor_index(); set_want_global_filter(nns_index && _approximate); } @@ -155,7 +107,7 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter, d void NearestNeighborBlueprint::perform_top_k(const search::tensor::NearestNeighborIndex* nns_index) { - auto lhs = _query_tensor->cells(); + auto lhs = _query_tensor.cells(); uint32_t k = _adjusted_target_hits; if (_global_filter->has_filter()) { auto filter = _global_filter->filter(); @@ -175,13 +127,12 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData switch (_algorithm) { case Algorithm::INDEX_TOP_K_WITH_FILTER: case Algorithm::INDEX_TOP_K: - return NnsIndexIterator::create(tfmd, _found_hits, _dist_fun); + return NnsIndexIterator::create(tfmd, _found_hits, _distance_calc.function()); default: ; } - const Value &qT = *_query_tensor; - return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor, - _distance_heap, _global_filter->filter(), _dist_fun); + return NearestNeighborIterator::create(strict, tfmd, _distance_calc, + _distance_heap, _global_filter->filter()); } void @@ -189,7 +140,7 @@ NearestNeighborBlueprint::visitMembers(vespalib::ObjectVisitor& visitor) const { ComplexLeafBlueprint::visitMembers(visitor); visitor.visitString("attribute_tensor", _attr_tensor.getTensorType().to_spec()); - visitor.visitString("query_tensor", _query_tensor->type().to_spec()); + visitor.visitString("query_tensor", _query_tensor.type().to_spec()); visitor.visitInt("target_hits", _target_hits); visitor.visitInt("adjusted_target_hits", _adjusted_target_hits); visitor.visitInt("explore_additional_hits", _explore_additional_hits); diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h index 16b0e13014e..3be7d7fd01f 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h @@ -3,6 +3,7 @@ #include "blueprint.h" #include "nearest_neighbor_distance_heap.h" +#include <vespa/searchlib/tensor/distance_calculator.h> #include <vespa/searchlib/tensor/distance_function.h> #include <vespa/searchlib/tensor/nearest_neighbor_index.h> #include <optional> @@ -28,7 +29,8 @@ public: }; private: const tensor::ITensorAttribute& _attr_tensor; - std::unique_ptr<vespalib::eval::Value> _query_tensor; + search::tensor::DistanceCalculator _distance_calc; + const vespalib::eval::Value& _query_tensor; uint32_t _target_hits; uint32_t _adjusted_target_hits; bool _approximate; @@ -36,8 +38,6 @@ private: double _distance_threshold; double _global_filter_lower_limit; double _global_filter_upper_limit; - search::tensor::DistanceFunction::UP _fallback_dist_fun; - const search::tensor::DistanceFunction *_dist_fun; mutable NearestNeighborDistanceHeap _distance_heap; std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits; Algorithm _algorithm; @@ -59,7 +59,7 @@ public: NearestNeighborBlueprint& operator=(const NearestNeighborBlueprint&) = delete; ~NearestNeighborBlueprint(); const tensor::ITensorAttribute& get_attribute_tensor() const { return _attr_tensor; } - const vespalib::eval::Value& get_query_tensor() const { return *_query_tensor; } + const vespalib::eval::Value& get_query_tensor() const { return _query_tensor; } uint32_t get_target_hits() const { return _target_hits; } uint32_t get_adjusted_target_hits() const { return _adjusted_target_hits; } void set_global_filter(const GlobalFilter &global_filter, double estimated_hit_ratio) override; diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp index 6a00568bd06..e06fcc614d8 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp @@ -2,6 +2,8 @@ #include "nearest_neighbor_iterator.h" #include <vespa/searchlib/common/bitvector.h> +#include <vespa/searchlib/tensor/distance_calculator.h> +#include <vespa/searchlib/tensor/distance_function.h> using search::tensor::ITensorAttribute; using vespalib::ConstArrayRef; @@ -34,11 +36,10 @@ public: NearestNeighborImpl(Params params_in) : NearestNeighborIterator(params_in), - _lhs(params().queryTensor.cells()), _lastScore(0.0) { - assert(is_compatible(params().tensorAttribute.getTensorType(), - params().queryTensor.type())); + assert(is_compatible(params().distance_calc.attribute_tensor().getTensorType(), + params().distance_calc.query_tensor().type())); } ~NearestNeighborImpl(); @@ -64,7 +65,7 @@ public: } void doUnpack(uint32_t docId) override { - double score = params().distanceFunction->to_rawscore(_lastScore); + double score = params().distance_calc.function().to_rawscore(_lastScore); params().tfmd.setRawScore(docId, score); params().distanceHeap.used(_lastScore); } @@ -73,11 +74,9 @@ public: private: double computeDistance(uint32_t docId, double limit) { - auto rhs = params().tensorAttribute.extract_cells_ref(docId); - return params().distanceFunction->calc_with_limit(_lhs, rhs, limit); + return params().distance_calc.calc_with_limit(docId, limit); } - TypedCells _lhs; double _lastScore; }; @@ -105,14 +104,12 @@ std::unique_ptr<NearestNeighborIterator> NearestNeighborIterator::create( bool strict, fef::TermFieldMatchData &tfmd, - const vespalib::eval::Value &queryTensor, - const search::tensor::ITensorAttribute &tensorAttribute, + const search::tensor::DistanceCalculator &distance_calc, NearestNeighborDistanceHeap &distanceHeap, - const search::BitVector *filter, - const search::tensor::DistanceFunction *dist_fun) + const search::BitVector *filter) { - Params params(tfmd, queryTensor, tensorAttribute, distanceHeap, filter, dist_fun); + Params params(tfmd, distance_calc, distanceHeap, filter); if (filter) { return resolve_strict<true>(strict, params); } else { diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h index 66622288d84..0d8f70d15c2 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h @@ -7,10 +7,11 @@ #include <vespa/eval/eval/value.h> #include <vespa/searchlib/fef/termfieldmatchdata.h> #include <vespa/searchlib/tensor/i_tensor_attribute.h> -#include <vespa/searchlib/tensor/distance_function.h> #include <vespa/vespalib/util/priority_queue.h> #include <cmath> +namespace search::tensor { class DistanceCalculator; } + namespace search::queryeval { class NearestNeighborIterator : public SearchIterator @@ -21,24 +22,18 @@ public: struct Params { fef::TermFieldMatchData &tfmd; - const Value &queryTensor; - const ITensorAttribute &tensorAttribute; + const search::tensor::DistanceCalculator &distance_calc; NearestNeighborDistanceHeap &distanceHeap; const search::BitVector *filter; - const search::tensor::DistanceFunction *distanceFunction; - + Params(fef::TermFieldMatchData &tfmd_in, - const Value &queryTensor_in, - const ITensorAttribute &tensorAttribute_in, + const search::tensor::DistanceCalculator &distance_calc_in, NearestNeighborDistanceHeap &distanceHeap_in, - const search::BitVector *filter_in, - const search::tensor::DistanceFunction *distanceFunction_in) + const search::BitVector *filter_in) : tfmd(tfmd_in), - queryTensor(queryTensor_in), - tensorAttribute(tensorAttribute_in), + distance_calc(distance_calc_in), distanceHeap(distanceHeap_in), - filter(filter_in), - distanceFunction(distanceFunction_in) + filter(filter_in) {} }; @@ -49,11 +44,9 @@ public: static std::unique_ptr<NearestNeighborIterator> create( bool strict, fef::TermFieldMatchData &tfmd, - const Value &queryTensor, - const search::tensor::ITensorAttribute &tensorAttribute, + const search::tensor::DistanceCalculator &distance_calc, NearestNeighborDistanceHeap &distanceHeap, - const search::BitVector *filter, - const search::tensor::DistanceFunction *dist_fun); + const search::BitVector *filter); const Params& params() const { return _params; } private: diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp index cd65f01025b..95264a79431 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp @@ -18,13 +18,13 @@ class NeighborVectorIterator : public NnsIndexIterator private: fef::TermFieldMatchData &_tfmd; const std::vector<Neighbor> &_hits; - const search::tensor::DistanceFunction * const _dist_fun; + const search::tensor::DistanceFunction &_dist_fun; uint32_t _idx; double _last_abstract_dist; public: NeighborVectorIterator(fef::TermFieldMatchData &tfmd, const std::vector<Neighbor> &hits, - const search::tensor::DistanceFunction *dist_fun) + const search::tensor::DistanceFunction &dist_fun) : _tfmd(tfmd), _hits(hits), _dist_fun(dist_fun), @@ -54,7 +54,7 @@ public: } void doUnpack(uint32_t docId) override { - double score = _dist_fun->to_rawscore(_last_abstract_dist); + double score = _dist_fun.to_rawscore(_last_abstract_dist); _tfmd.setRawScore(docId, score); } @@ -65,7 +65,7 @@ std::unique_ptr<NnsIndexIterator> NnsIndexIterator::create( fef::TermFieldMatchData &tfmd, const std::vector<Neighbor> &hits, - const search::tensor::DistanceFunction *dist_fun) + const search::tensor::DistanceFunction &dist_fun) { return std::make_unique<NeighborVectorIterator>(tfmd, hits, dist_fun); } diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h index 019ac8579bd..031a603de49 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h +++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h @@ -16,7 +16,7 @@ public: static std::unique_ptr<NnsIndexIterator> create( fef::TermFieldMatchData &tfmd, const std::vector<Hit> &hits, - const search::tensor::DistanceFunction *dist_fun); + const search::tensor::DistanceFunction &dist_fun); }; } // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt index ae34cdd66c8..9e0ccb8d37a 100644 --- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt @@ -11,6 +11,7 @@ vespa_add_library(searchlib_tensor OBJECT direct_tensor_attribute.cpp direct_tensor_saver.cpp direct_tensor_store.cpp + distance_calculator.cpp distance_function_factory.cpp euclidean_distance.cpp geo_degrees_distance.cpp diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp new file mode 100644 index 00000000000..6bb3d9ed49b --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp @@ -0,0 +1,98 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "distance_calculator.h" +#include "distance_function_factory.h" +#include "i_tensor_attribute.h" +#include "nearest_neighbor_index.h" +#include <vespa/eval/eval/fast_value.h> + +using vespalib::eval::CellType; +using vespalib::eval::FastValueBuilderFactory; +using vespalib::eval::TypedCells; +using vespalib::eval::Value; +using vespalib::eval::ValueType; + +namespace { + +template<typename LCT, typename RCT> +std::unique_ptr<Value> +convert_cells(const ValueType& new_type, std::unique_ptr<Value> old_value) +{ + auto old_cells = old_value->cells().typify<LCT>(); + auto builder = FastValueBuilderFactory::get().create_value_builder<RCT>(new_type); + auto new_cells = builder->add_subspace(); + assert(old_cells.size() == new_cells.size()); + auto p = new_cells.begin(); + for (LCT value : old_cells) { + RCT conv(value); + *p++ = conv; + } + return builder->build(std::move(builder)); +} + +struct ConvertCellsSelector +{ + template <typename LCT, typename RCT> + static auto invoke(const ValueType& new_type, std::unique_ptr<Value> old_value) { + return convert_cells<LCT, RCT>(new_type, std::move(old_value)); + } + auto operator() (CellType from, CellType to, std::unique_ptr<Value> old_value) const { + using MyTypify = vespalib::eval::TypifyCellType; + ValueType new_type = old_value->type().cell_cast(to); + return vespalib::typify_invoke<2,MyTypify,ConvertCellsSelector>(from, to, new_type, std::move(old_value)); + } +}; + +} + +namespace search::tensor { + +DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tensor, + std::unique_ptr<vespalib::eval::Value> query_tensor_in) + : _attr_tensor(attr_tensor), + _query_tensor_uptr(std::move(query_tensor_in)), + _query_tensor(), + _query_tensor_cells(), + _dist_fun_uptr(make_distance_function(_attr_tensor.distance_metric(), + _attr_tensor.getTensorType().cell_type())), + _dist_fun(_dist_fun_uptr.get()) +{ + assert(_dist_fun); + auto nns_index = _attr_tensor.nearest_neighbor_index(); + if (nns_index) { + _dist_fun = nns_index->distance_function(); + assert(_dist_fun); + } + auto query_ct = _query_tensor_uptr->cells().type; + CellType required_ct = _dist_fun->expected_cell_type(); + if (query_ct != required_ct) { + ConvertCellsSelector converter; + _query_tensor_uptr = converter(query_ct, required_ct, std::move(_query_tensor_uptr)); + } + _query_tensor = _query_tensor_uptr.get(); + _query_tensor_cells = _query_tensor->cells(); +} + +DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tensor, + const vespalib::eval::Value& query_tensor_in, + const DistanceFunction& function_in) + : _attr_tensor(attr_tensor), + _query_tensor_uptr(), + _query_tensor(&query_tensor_in), + _query_tensor_cells(_query_tensor->cells()), + _dist_fun_uptr(), + _dist_fun(&function_in) +{ +} + +DistanceCalculator::~DistanceCalculator() = default; + +double +DistanceCalculator::calc_with_limit(uint32_t docid, double limit) const +{ + auto rhs = _attr_tensor.extract_cells_ref(docid); + return _dist_fun->calc_with_limit(_query_tensor_cells, rhs, limit); +} + +} + diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h new file mode 100644 index 00000000000..df9344a24d1 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h @@ -0,0 +1,49 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <vespa/eval/eval/typed_cells.h> +#include <memory> + +namespace vespalib::eval { struct Value; } + +namespace search::tensor { + +class DistanceFunction; +class ITensorAttribute; + +/** + * Class used to calculate the distance between two n-dimensional vectors, + * where one is stored in a TensorAttribute and the other comes from the query. + * + * The distance function to use is defined in the TensorAttribute. + */ +class DistanceCalculator { +private: + const tensor::ITensorAttribute& _attr_tensor; + std::unique_ptr<vespalib::eval::Value> _query_tensor_uptr; + const vespalib::eval::Value* _query_tensor; + vespalib::eval::TypedCells _query_tensor_cells; + std::unique_ptr<DistanceFunction> _dist_fun_uptr; + const DistanceFunction* _dist_fun; + +public: + DistanceCalculator(const tensor::ITensorAttribute& attr_tensor, + std::unique_ptr<vespalib::eval::Value> query_tensor_in); + + /** + * Only used by unit tests where ownership of query tensor and distance function is handled outside. + */ + DistanceCalculator(const tensor::ITensorAttribute& attr_tensor, + const vespalib::eval::Value& query_tensor_in, + const DistanceFunction& function_in); + + ~DistanceCalculator(); + + const tensor::ITensorAttribute& attribute_tensor() const { return _attr_tensor; } + const vespalib::eval::Value& query_tensor() const { return *_query_tensor; } + const DistanceFunction& function() const { return *_dist_fun; } + + double calc_with_limit(uint32_t docid, double limit) const; +}; + +} |