summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author: Arne Juul <arnej@yahooinc.com> 2023-04-20 13:26:02 +0000
committer: Arne Juul <arnej@yahooinc.com> 2023-04-24 07:58:38 +0000
commit: 06ae8f17e040fe27e19dd7bdf4857ce0c4ccaba1 (patch)
tree: 95635dda48f7f188f87e4b62ee2852a7cfb652c3
parent: bc7e1b8c1a96a49838e55a7b97c4de565abc9ad3 (diff)
also add BoundEuclideanDistance
-rw-r--r-- searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp | 82
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp | 1
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/bound_distance_function.h | 3
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp | 7
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp | 69
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/euclidean_distance.h | 12
6 files changed, 163 insertions, 11 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index 86b83b2c651..ae283f3f2b2 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -44,6 +44,30 @@ void verify_geo_miles(const DistanceFunction *dist_fun,
}
}
+double computeEuclideanChecked(TypedCells a, TypedCells b) { // Test helper: returns d_n->calc(b) after cross-checking it against the other bound-function variants built here.
+ static EuclideanDistanceFunctionFactory<Int8Float> i8f_dff; // one function-local static factory per FloatType instantiation, reused across calls
+ static EuclideanDistanceFunctionFactory<float> flt_dff;
+ static EuclideanDistanceFunctionFactory<double> dbl_dff;
+ auto d_n = dbl_dff.for_query_vector(a); // double factory, bound to a: produces the reference value
+ auto d_f = flt_dff.for_query_vector(a); // float factory, bound to a
+ auto d_r = dbl_dff.for_query_vector(b); // double factory, bound to b (operands swapped)
+ auto d_i = dbl_dff.for_insertion_vector(a); // double factory, insertion variant bound to a
+ // normal: reference value from the double-typed query function
+ double result = d_n->calc(b);
+ // insert is exactly same: the insertion variant must agree bit-for-bit (EXPECT_EQ)
+ EXPECT_EQ(d_i->calc(b), result);
+ // reverse: swapping bound and argument vector must agree to double tolerance
+ EXPECT_DOUBLE_EQ(d_r->calc(a), result);
+ // float factory: only float-level agreement with the double result is required
+ EXPECT_FLOAT_EQ(d_f->calc(b), result);
+ if (a.type == vespalib::eval::CellType::INT8 || // the Int8Float factory is checked when either operand has INT8 cells
+ b.type == vespalib::eval::CellType::INT8)
+ {
+ auto d_8 = i8f_dff.for_query_vector(a);
+ EXPECT_DOUBLE_EQ(d_8->calc(b), result); // NOTE(review): presumably only meaningful when both operands hold int8-representable values - confirm
+ }
+ return result;
+}
TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
{
@@ -59,15 +83,56 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
std::vector<double> p5{0.0,-1.0, 0.0};
std::vector<double> p6{1.0, 2.0, 2.0};
- double n4 = euclid->calc(t(p0), t(p4));
+ double n4 = computeEuclideanChecked(t(p0), t(p4));
EXPECT_FLOAT_EQ(n4, 1.0);
- double d12 = euclid->calc(t(p1), t(p2));
+ double d12 = computeEuclideanChecked(t(p1), t(p2));
EXPECT_EQ(d12, 2.0);
EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0)));
double threshold = euclid->convert_threshold(8.0);
EXPECT_EQ(threshold, 64.0);
threshold = euclid->convert_threshold(0.5);
EXPECT_EQ(threshold, 0.25);
+
+ // simple hand-checked distances:
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p0)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p1)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p2)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p3)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p5)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p6)), 9.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p1)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p2)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p3)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p5)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p6)), 8.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p2)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p3)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p5)), 4.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p6)), 6.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p3)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p5)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p6)), 6.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p5), t(p5)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p5), t(p6)), 14.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p6), t(p6)), 0.0);
+
+ // smoke test for bfloat16:
+ std::vector<vespalib::BFloat16> bf16v;
+ bf16v.emplace_back(1.0);
+ bf16v.emplace_back(1.0);
+ bf16v.emplace_back(1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p0)), 3.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p1)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p2)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p3)), 2.0);
+ EXPECT_FLOAT_EQ(computeEuclideanChecked(t(bf16v), t(p4)), 0.5857863);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p5)), 6.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p6)), 2.0);
}
TEST(DistanceFunctionsTest, euclidean_int8_smoketest)
@@ -81,14 +146,13 @@ TEST(DistanceFunctionsTest, euclidean_int8_smoketest)
std::vector<Int8Float> p5{0.0,-1.0, 0.0};
std::vector<Int8Float> p7{-1.0, 2.0, -2.0};
- EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p0), t(p1)));
- EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p0), t(p5)));
- EXPECT_DOUBLE_EQ(9.0, euclid->calc(t(p0), t(p7)));
-
- EXPECT_DOUBLE_EQ(2.0, euclid->calc(t(p1), t(p5)));
- EXPECT_DOUBLE_EQ(12.0, euclid->calc(t(p1), t(p7)));
- EXPECT_DOUBLE_EQ(14.0, euclid->calc(t(p5), t(p7)));
+ EXPECT_DOUBLE_EQ(1.0, computeEuclideanChecked(t(p0), t(p1)));
+ EXPECT_DOUBLE_EQ(1.0, computeEuclideanChecked(t(p0), t(p5)));
+ EXPECT_DOUBLE_EQ(9.0, computeEuclideanChecked(t(p0), t(p7)));
+ EXPECT_DOUBLE_EQ(2.0, computeEuclideanChecked(t(p1), t(p5)));
+ EXPECT_DOUBLE_EQ(12.0, computeEuclideanChecked(t(p1), t(p7)));
+ EXPECT_DOUBLE_EQ(14.0, computeEuclideanChecked(t(p5), t(p7)));
}
double computeAngularChecked(TypedCells a, TypedCells b) {
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
index 56edbf9fede..19c2e744954 100644
--- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
@@ -49,6 +49,7 @@ TemporaryVectorStore<FloatType>::internal_convert(TypedCells cells, size_t offse
return result;
}
+template class TemporaryVectorStore<vespalib::eval::Int8Float>; // explicit instantiation for Int8Float; BoundEuclideanDistance<Int8Float> holds a TemporaryVectorStore<Int8Float> member
template class TemporaryVectorStore<float>;
template class TemporaryVectorStore<double>;
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
index 0f51e8a33ef..8949aea8796 100644
--- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
@@ -5,7 +5,6 @@
#include <memory>
#include <vespa/eval/eval/cell_type.h>
#include <vespa/eval/eval/typed_cells.h>
-#include <vespa/vespalib/util/array.h>
#include <vespa/vespalib/util/arrayref.h>
#include "distance_function.h"
@@ -48,7 +47,7 @@ public:
template <typename FloatType>
class TemporaryVectorStore {
private:
- vespalib::Array<FloatType> _tmpSpace;
+ std::vector<FloatType> _tmpSpace;
vespalib::ConstArrayRef<FloatType> internal_convert(vespalib::eval::TypedCells cells, size_t offset);
public:
TemporaryVectorStore(size_t vectorSize) : _tmpSpace(vectorSize * 2) {}
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
index 7ccca655943..cca492ef212 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
@@ -100,6 +100,13 @@ make_distance_function_factory(search::attribute::DistanceMetric variant,
}
return std::make_unique<AngularDistanceFunctionFactory<float>>();
}
+ if (variant == DistanceMetric::Euclidean) { // Euclidean returns a dedicated EuclideanDistanceFunctionFactory chosen by cell type, before reaching the generic fallback after this block
+ switch (cell_type) {
+ case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>();
+ case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>();
+ default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>(); // every other cell type uses the float instantiation
+ }
+ }
auto df = make_distance_function(variant, cell_type);
return std::make_unique<SimpleDistanceFunctionFactory>(std::move(df));
}
diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
index c83f1821321..6a54798883f 100644
--- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
@@ -48,4 +48,73 @@ SquaredEuclideanDistance::calc_with_limit(const vespalib::eval::TypedCells& lhs,
template class SquaredEuclideanDistanceHW<float>;
template class SquaredEuclideanDistanceHW<double>;
+using vespalib::eval::Int8Float;
+
+template<typename FloatType> // FloatType: element type of the stored vectors; this file instantiates Int8Float, float and double
+class BoundEuclideanDistance : public BoundDistanceFunction { // Squared euclidean distance with the left-hand vector fixed at construction time.
+private:
+ const vespalib::hwaccelrated::IAccelrated & _computer; // accelerator obtained from IAccelrated::getAccelerator() in the constructor
+ mutable TemporaryVectorStore<FloatType> _tmpSpace; // mutable: the const calc()/calc_with_limit() call convertRhs() on it. NOTE(review): presumably shared scratch space, so concurrent calls on one instance look unsafe - confirm
+ const vespalib::ConstArrayRef<FloatType> _lhs_vector; // view returned by _tmpSpace.storeLhs(lhs); declared after _tmpSpace because it is initialized from it
+ static const double *cast(const double * p) { return p; } // cast() overloads pick the pointer type handed to squaredEuclideanDistance(); double and float pass through unchanged
+ static const float *cast(const float * p) { return p; }
+ static const int8_t *cast(const Int8Float * p) { return reinterpret_cast<const int8_t *>(p); } // NOTE(review): assumes Int8Float is layout-compatible with a single int8_t - confirm
+public:
+ BoundEuclideanDistance(const vespalib::eval::TypedCells& lhs) // Binds lhs: sizes the temporary store from lhs.size and keeps the view returned by storeLhs().
+ : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()),
+ _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
+ _tmpSpace(lhs.size),
+ _lhs_vector(_tmpSpace.storeLhs(lhs))
+ {}
+ double calc(const vespalib::eval::TypedCells& rhs) const override { // Returns _computer.squaredEuclideanDistance() between the bound lhs and rhs.
+ size_t sz = _lhs_vector.size();
+ vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
+ assert(sz == rhs_vector.size()); // both vectors must have the same number of cells
+ auto a = &_lhs_vector[0];
+ auto b = &rhs_vector[0];
+ return _computer.squaredEuclideanDistance(cast(a), cast(b), sz);
+ }
+ double convert_threshold(double threshold) const override { // squares the threshold, matching the squared distances returned by calc()
+ return threshold*threshold;
+ }
+ double to_rawscore(double distance) const override { // maps a squared distance to 1 / (1 + sqrt(distance))
+ double d = sqrt(distance);
+ double score = 1.0 / (1.0 + d);
+ return score;
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override { // Scalar loop (does not use _computer) that may stop early; see loop condition.
+ vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
+ double sum = 0.0;
+ size_t sz = _lhs_vector.size();
+ assert(sz == rhs_vector.size());
+ for (size_t i = 0; i < sz && sum <= limit; ++i) { // stops once sum exceeds limit, so the returned value can be a partial sum that is already above limit
+ double diff = _lhs_vector[i] - rhs_vector[i];
+ sum += diff*diff;
+ }
+ return sum;
+ }
+};
+
+template class BoundEuclideanDistance<Int8Float>; // explicit instantiations: the class template is defined in this .cpp file, so each FloatType in use is instantiated here
+template class BoundEuclideanDistance<float>;
+template class BoundEuclideanDistance<double>;
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+EuclideanDistanceFunctionFactory<FloatType>::for_query_vector(const vespalib::eval::TypedCells& lhs) { // Returns a new BoundEuclideanDistance<FloatType> constructed from lhs.
+ using DFT = BoundEuclideanDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+EuclideanDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib::eval::TypedCells& lhs) { // Returns a new BoundEuclideanDistance<FloatType> constructed from lhs; the body is the same as for_query_vector().
+ using DFT = BoundEuclideanDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template class EuclideanDistanceFunctionFactory<Int8Float>; // explicit instantiations: the member functions are defined in this .cpp file, so each FloatType in use is instantiated here
+template class EuclideanDistanceFunctionFactory<float>;
+template class EuclideanDistanceFunctionFactory<double>;
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
index 6505ea119ea..b406f0d3d1a 100644
--- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
+++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
@@ -3,6 +3,7 @@
#pragma once
#include "distance_function.h"
+#include "distance_function_factory.h"
#include <vespa/eval/eval/typed_cells.h>
#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
#include <cmath>
@@ -78,4 +79,15 @@ private:
const vespalib::hwaccelrated::IAccelrated & _computer;
};
+
+template <typename FloatType> // FloatType determines the cell type passed to the DistanceFunctionFactory base via get_cell_type<FloatType>()
+class EuclideanDistanceFunctionFactory : public DistanceFunctionFactory { // Factory whose two overrides each return a BoundDistanceFunction::UP for a given left-hand vector.
+public:
+ EuclideanDistanceFunctionFactory()
+ : DistanceFunctionFactory(vespalib::eval::get_cell_type<FloatType>())
+ {}
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override; // only declared here; the definition is not in this header
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override; // only declared here; the definition is not in this header
+};
+
}