aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-04-19 13:41:16 +0000
committerArne Juul <arnej@yahooinc.com>2023-04-20 09:30:24 +0000
commit4e3fd9eeebeb403d4ad23bf70470d895cbdfbd1c (patch)
tree6a24edf1fd2f0cdb575a94e1f86dfcda17bcc450 /searchlib/src
parentd525287130a57e0de1e0f89332d8bbf67481e528 (diff)
prepare for more advanced BoundDistanceFunction
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp55
-rw-r--r--searchlib/src/vespa/searchlib/tensor/bound_distance_function.h47
2 files changed, 102 insertions, 0 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
index 33b94e5218c..56edbf9fede 100644
--- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
@@ -1,3 +1,58 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "bound_distance_function.h"
+#include <vespa/log/log.h>
+
+LOG_SETUP(".searchlib.tensor.bound_distance_function");
+
+using vespalib::ConstArrayRef;
+using vespalib::ArrayRef;
+using vespalib::eval::CellType;
+using vespalib::eval::TypedCells;
+
+namespace search::tensor {
+
+namespace {
+
+template<typename FromType, typename ToType>
+ConstArrayRef<ToType>
+convert_cells(ArrayRef<ToType> space, TypedCells cells)
+{
+ assert(cells.size == space.size());
+ auto old_cells = cells.typify<FromType>();
+ ToType *p = space.data();
+ for (FromType value : old_cells) {
+ ToType conv(value);
+ *p++ = conv;
+ }
+ return space;
+}
+
+template <typename ToType>
+struct ConvertCellsSelector
+{
+ template <typename FromType> static auto invoke(ArrayRef<ToType> dst, TypedCells src) {
+ return convert_cells<FromType, ToType>(dst, src);
+ }
+};
+
+} // namespace
+
+template <typename FloatType>
+ConstArrayRef<FloatType>
+TemporaryVectorStore<FloatType>::internal_convert(TypedCells cells, size_t offset) {
+ LOG_ASSERT(cells.size * 2 == _tmpSpace.size());
+ ArrayRef<FloatType> where(_tmpSpace.data() + offset, cells.size);
+ using MyTypify = vespalib::eval::TypifyCellType;
+ using MySelector = ConvertCellsSelector<FloatType>;
+ ConstArrayRef<FloatType> result = vespalib::typify_invoke<1,MyTypify,MySelector>(cells.type, where, cells);
+ return result;
+}
+
+template class TemporaryVectorStore<float>;
+template class TemporaryVectorStore<double>;
+
+template class ConvertingBoundDistance<float>;
+template class ConvertingBoundDistance<double>;
+
+}
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
index 17e9e49cada..5d310dec32d 100644
--- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
@@ -41,4 +41,51 @@ public:
double limit) const = 0;
};
+
+/** helper class - temporary storage of possibly-converted vector cells */
+template <typename FloatType>
+class TemporaryVectorStore {
+private:
+ vespalib::Array<FloatType> _tmpSpace;
+ vespalib::ConstArrayRef<FloatType> internal_convert(vespalib::eval::TypedCells cells, size_t offset);
+public:
+ TemporaryVectorStore(size_t vectorSize) : _tmpSpace(vectorSize * 2) {}
+ vespalib::ConstArrayRef<FloatType> storeLhs(vespalib::eval::TypedCells cells) {
+ return internal_convert(cells, 0);
+ }
+ vespalib::ConstArrayRef<FloatType> convertRhs(vespalib::eval::TypedCells cells) {
+ if (vespalib::eval::get_cell_type<FloatType>() == cells.type) [[likely]] {
+ return cells.unsafe_typify<FloatType>();
+ } else {
+ return internal_convert(cells, cells.size);
+ }
+ }
+};
+
+template<typename FloatType>
+class ConvertingBoundDistance : public BoundDistanceFunction {
+ mutable TemporaryVectorStore<FloatType> _tmpSpace;
+ const vespalib::eval::TypedCells _lhs;
+ const DistanceFunction &_df;
+public:
+ ConvertingBoundDistance(const vespalib::eval::TypedCells& lhs, const DistanceFunction &df)
+ : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()),
+ _tmpSpace(lhs.size),
+ _lhs(_tmpSpace.storeLhs(lhs)),
+ _df(df)
+ {}
+ double calc(const vespalib::eval::TypedCells& rhs) const override {
+ return _df.calc(_lhs, vespalib::eval::TypedCells(_tmpSpace.convertRhs(rhs)));
+ }
+ double convert_threshold(double threshold) const override {
+ return _df.convert_threshold(threshold);
+ }
+ double to_rawscore(double distance) const override {
+ return _df.to_rawscore(distance);
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override {
+ return _df.calc_with_limit(_lhs, vespalib::eval::TypedCells(_tmpSpace.convertRhs(rhs)), limit);
+ }
+};
+
}