summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-04-27 09:45:13 +0000
committerArne Juul <arnej@yahooinc.com>2023-04-28 09:15:42 +0000
commit4b33a265e217f3e242b0d32c8e0a720fc33352c4 (patch)
tree31ab4e9e1a3f9c10cb89940df254d8082d240041
parentefebd9ea2938466b0f6912365cf9bbb7b6253541 (diff)
add proof-of-concept for Maximum Inner Product Search
-rw-r--r--searchlib/src/vespa/searchlib/tensor/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/mips_distance_transform.cpp90
-rw-r--r--searchlib/src/vespa/searchlib/tensor/mips_distance_transform.h39
3 files changed, 130 insertions, 0 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
index 362bceaf0da..24c8db4863e 100644
--- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
@@ -2,6 +2,7 @@
vespa_add_library(searchlib_tensor OBJECT
SOURCES
angular_distance.cpp
+ mips_distance_transform.cpp
bitvector_visited_tracker.cpp
bound_distance_function.cpp
default_nearest_neighbor_index_factory.cpp
diff --git a/searchlib/src/vespa/searchlib/tensor/mips_distance_transform.cpp b/searchlib/src/vespa/searchlib/tensor/mips_distance_transform.cpp
new file mode 100644
index 00000000000..fd2295bef76
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/mips_distance_transform.cpp
@@ -0,0 +1,90 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "mips_distance_transform.h"
+#include "temporary_vector_store.h"
+#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
+#include <cmath>
+#include <mutex>
+#include <variant>
+
+using vespalib::eval::Int8Float;
+
+namespace search::tensor {
+
+template<typename FloatType, bool extra_dim>
+class BoundMipsDistanceFunction : public BoundDistanceFunction {
+ mutable TemporaryVectorStore<FloatType> _tmpSpace;
+ const vespalib::ConstArrayRef<FloatType> _lhs_vector;
+ const vespalib::hwaccelrated::IAccelrated & _computer;
+ double _max_sq_norm;
+ using ExtraDimT = std::conditional<extra_dim,double,std::monostate>::type;
+ [[no_unique_address]] ExtraDimT _lhs_extra_dim;
+
+ static const double *cast(const double * p) { return p; }
+ static const float *cast(const float * p) { return p; }
+ static const int8_t *cast(const Int8Float * p) { return reinterpret_cast<const int8_t *>(p); }
+public:
+ BoundMipsDistanceFunction(const vespalib::eval::TypedCells& lhs, MaximumSquaredNormStore& sq_norm_store)
+ : BoundDistanceFunction(),
+ _tmpSpace(lhs.size),
+ _lhs_vector(_tmpSpace.storeLhs(lhs)),
+ _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator())
+ {
+ const FloatType * a = _lhs_vector.data();
+ if constexpr (extra_dim) {
+ double lhs_sq_norm = _computer.dotProduct(cast(a), cast(a), lhs.size);
+ _max_sq_norm = sq_norm_store.get_max(lhs_sq_norm);
+ _lhs_extra_dim = std::sqrt(_max_sq_norm - lhs_sq_norm);
+ } else {
+ _max_sq_norm = sq_norm_store.get_max();
+ }
+ }
+
+ double get_extra_dim_value() requires extra_dim {
+ return _lhs_extra_dim;
+ }
+
+ double calc(const vespalib::eval::TypedCells &rhs) const override {
+ vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
+ const FloatType * a = _lhs_vector.data();
+ const FloatType * b = rhs_vector.data();
+ double dp = _computer.dotProduct(cast(a), cast(b), rhs.size);
+ if constexpr (extra_dim) {
+ double rhs_sq_norm = _computer.dotProduct(cast(b), cast(b), rhs.size);
+ double rhs_extra_dim = std::sqrt(_max_sq_norm - rhs_sq_norm);
+ dp += _lhs_extra_dim * rhs_extra_dim;
+ }
+ return -dp;
+ }
+ double convert_threshold(double threshold) const override {
+ return threshold;
+ }
+ double to_rawscore(double distance) const override {
+ double dp = -distance;
+ double t1 = dp / _max_sq_norm;
+ double t2 = t1 / (1.0 + std::fabs(t1));
+ double r = (t2 + 1.0) * 0.5;
+ return r;
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override {
+ return calc(rhs);
+ }
+};
+
+template<typename FloatType>
+BoundDistanceFunction::UP
+MipsDistanceFunctionFactory<FloatType>::for_query_vector(const vespalib::eval::TypedCells& lhs) {
+ return std::make_unique<BoundMipsDistanceFunction<FloatType, false>>(lhs, *_sq_norm_store);
+}
+
+template<typename FloatType>
+BoundDistanceFunction::UP
+MipsDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib::eval::TypedCells& lhs) {
+ return std::make_unique<BoundMipsDistanceFunction<FloatType, true>>(lhs, *_sq_norm_store);
+};
+
+template class MipsDistanceFunctionFactory<Int8Float>;
+template class MipsDistanceFunctionFactory<float>;
+template class MipsDistanceFunctionFactory<double>;
+
+}
diff --git a/searchlib/src/vespa/searchlib/tensor/mips_distance_transform.h b/searchlib/src/vespa/searchlib/tensor/mips_distance_transform.h
new file mode 100644
index 00000000000..86dbe6e4d1e
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/mips_distance_transform.h
@@ -0,0 +1,39 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "distance_function.h"
+#include "distance_function_factory.h"
+#include <vespa/eval/eval/typed_cells.h>
+#include <mutex>
+#include <memory>
+
+namespace search::tensor {
+
+class MaximumSquaredNormStore {
+private:
+ std::mutex _lock;
+ double _max_sq_norm;
+public:
+ MaximumSquaredNormStore() noexcept : _lock(), _max_sq_norm(0.0) {}
+ double get_max(double value = 0.0) {
+ std::lock_guard<std::mutex> guard(_lock);
+ if (value > _max_sq_norm) [[unlikely]] {
+ _max_sq_norm = value;
+ }
+ return _max_sq_norm;
+ }
+};
+
+template<typename FloatType>
+class MipsDistanceFunctionFactory : public DistanceFunctionFactory {
+ std::shared_ptr<MaximumSquaredNormStore> _sq_norm_store;
+public:
+ MipsDistanceFunctionFactory() : _sq_norm_store(std::make_shared<MaximumSquaredNormStore>()) {}
+
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override;
+
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override;
+};
+
+}