diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-04-19 13:41:16 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-04-20 09:30:24 +0000 |
commit | 4e3fd9eeebeb403d4ad23bf70470d895cbdfbd1c (patch) | |
tree | 6a24edf1fd2f0cdb575a94e1f86dfcda17bcc450 /searchlib/src | |
parent | d525287130a57e0de1e0f89332d8bbf67481e528 (diff) |
prepare for more advanced BoundDistanceFunction
Diffstat (limited to 'searchlib/src')
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp | 55 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/bound_distance_function.h | 47 |
2 files changed, 102 insertions, 0 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp index 33b94e5218c..56edbf9fede 100644 --- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp +++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp @@ -1,3 +1,58 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "bound_distance_function.h" +#include <vespa/log/log.h> + +LOG_SETUP(".searchlib.tensor.bound_distance_function"); + +using vespalib::ConstArrayRef; +using vespalib::ArrayRef; +using vespalib::eval::CellType; +using vespalib::eval::TypedCells; + +namespace search::tensor { + +namespace { + +template<typename FromType, typename ToType> +ConstArrayRef<ToType> +convert_cells(ArrayRef<ToType> space, TypedCells cells) +{ + assert(cells.size == space.size()); + auto old_cells = cells.typify<FromType>(); + ToType *p = space.data(); + for (FromType value : old_cells) { + ToType conv(value); + *p++ = conv; + } + return space; +} + +template <typename ToType> +struct ConvertCellsSelector +{ + template <typename FromType> static auto invoke(ArrayRef<ToType> dst, TypedCells src) { + return convert_cells<FromType, ToType>(dst, src); + } +}; + +} // namespace + +template <typename FloatType> +ConstArrayRef<FloatType> +TemporaryVectorStore<FloatType>::internal_convert(TypedCells cells, size_t offset) { + LOG_ASSERT(cells.size * 2 == _tmpSpace.size()); + ArrayRef<FloatType> where(_tmpSpace.data() + offset, cells.size); + using MyTypify = vespalib::eval::TypifyCellType; + using MySelector = ConvertCellsSelector<FloatType>; + ConstArrayRef<FloatType> result = vespalib::typify_invoke<1,MyTypify,MySelector>(cells.type, where, cells); + return result; +} + +template class TemporaryVectorStore<float>; +template class TemporaryVectorStore<double>; + +template class ConvertingBoundDistance<float>; +template class ConvertingBoundDistance<double>; + +} diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h index 17e9e49cada..5d310dec32d 100644 --- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h +++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h @@ -41,4 +41,51 @@ public: double limit) const = 0; }; + +/** helper class - temporary storage of possibly-converted vector cells */ +template <typename FloatType> +class TemporaryVectorStore { +private: + vespalib::Array<FloatType> _tmpSpace; + vespalib::ConstArrayRef<FloatType> internal_convert(vespalib::eval::TypedCells cells, size_t offset); +public: + TemporaryVectorStore(size_t vectorSize) : _tmpSpace(vectorSize * 2) {} + vespalib::ConstArrayRef<FloatType> storeLhs(vespalib::eval::TypedCells cells) { + return internal_convert(cells, 0); + } + vespalib::ConstArrayRef<FloatType> convertRhs(vespalib::eval::TypedCells cells) { + if (vespalib::eval::get_cell_type<FloatType>() == cells.type) [[likely]] { + return cells.unsafe_typify<FloatType>(); + } else { + return internal_convert(cells, cells.size); + } + } +}; + +template<typename FloatType> +class ConvertingBoundDistance : public BoundDistanceFunction { + mutable TemporaryVectorStore<FloatType> _tmpSpace; + const vespalib::eval::TypedCells _lhs; + const DistanceFunction &_df; +public: + ConvertingBoundDistance(const vespalib::eval::TypedCells& lhs, const DistanceFunction &df) + : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()), + _tmpSpace(lhs.size), + _lhs(_tmpSpace.storeLhs(lhs)), + _df(df) + {} + double calc(const vespalib::eval::TypedCells& rhs) const override { + return _df.calc(_lhs, vespalib::eval::TypedCells(_tmpSpace.convertRhs(rhs))); + } + double convert_threshold(double threshold) const override { + return _df.convert_threshold(threshold); + } + double to_rawscore(double distance) const override { + return _df.to_rawscore(distance); + } + double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override { + return _df.calc_with_limit(_lhs, vespalib::eval::TypedCells(_tmpSpace.convertRhs(rhs)), limit); + } +}; + } |