summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2023-04-24 17:12:04 +0200
committerGitHub <noreply@github.com>2023-04-24 17:12:04 +0200
commit8c52052c3150c209db37a9a2747c54ccb7d4e171 (patch)
tree228c711456d9c17e7d51d34fd07c51872340fc3f
parentd54798f7eb271a99e80fb7aa2e7487fa148d5166 (diff)
parent46dcacf8b4402f349411171588e1485c2e58569f (diff)
Merge pull request #26829 from vespa-engine/arnej/bound-euclidean-distance
Arnej/bound euclidean distance
-rw-r--r--searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp82
-rw-r--r--searchlib/src/vespa/searchlib/tensor/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/angular_distance.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp55
-rw-r--r--searchlib/src/vespa/searchlib/tensor/bound_distance_function.h48
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp7
-rw-r--r--searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp70
-rw-r--r--searchlib/src/vespa/searchlib/tensor/euclidean_distance.h12
-rw-r--r--searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp57
-rw-r--r--searchlib/src/vespa/searchlib/tensor/temporary_vector_store.h32
10 files changed, 253 insertions, 112 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index 86b83b2c651..ae283f3f2b2 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -44,6 +44,30 @@ void verify_geo_miles(const DistanceFunction *dist_fun,
}
}
+double computeEuclideanChecked(TypedCells a, TypedCells b) {
+ static EuclideanDistanceFunctionFactory<Int8Float> i8f_dff;
+ static EuclideanDistanceFunctionFactory<float> flt_dff;
+ static EuclideanDistanceFunctionFactory<double> dbl_dff;
+ auto d_n = dbl_dff.for_query_vector(a);
+ auto d_f = flt_dff.for_query_vector(a);
+ auto d_r = dbl_dff.for_query_vector(b);
+ auto d_i = dbl_dff.for_insertion_vector(a);
+ // normal:
+ double result = d_n->calc(b);
+ // insert is exactly same:
+ EXPECT_EQ(d_i->calc(b), result);
+ // reverse:
+ EXPECT_DOUBLE_EQ(d_r->calc(a), result);
+ // float factory:
+ EXPECT_FLOAT_EQ(d_f->calc(b), result);
+ if (a.type == vespalib::eval::CellType::INT8 ||
+ b.type == vespalib::eval::CellType::INT8)
+ {
+ auto d_8 = i8f_dff.for_query_vector(a);
+ EXPECT_DOUBLE_EQ(d_8->calc(b), result);
+ }
+ return result;
+}
TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
{
@@ -59,15 +83,56 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
std::vector<double> p5{0.0,-1.0, 0.0};
std::vector<double> p6{1.0, 2.0, 2.0};
- double n4 = euclid->calc(t(p0), t(p4));
+ double n4 = computeEuclideanChecked(t(p0), t(p4));
EXPECT_FLOAT_EQ(n4, 1.0);
- double d12 = euclid->calc(t(p1), t(p2));
+ double d12 = computeEuclideanChecked(t(p1), t(p2));
EXPECT_EQ(d12, 2.0);
EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0)));
double threshold = euclid->convert_threshold(8.0);
EXPECT_EQ(threshold, 64.0);
threshold = euclid->convert_threshold(0.5);
EXPECT_EQ(threshold, 0.25);
+
+ // simple hand-checked distances:
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p0)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p1)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p2)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p3)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p5)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p6)), 9.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p1)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p2)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p3)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p5)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p6)), 8.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p2)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p3)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p5)), 4.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p6)), 6.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p3)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p5)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p6)), 6.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p5), t(p5)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p5), t(p6)), 14.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p6), t(p6)), 0.0);
+
+ // smoke test for bfloat16:
+ std::vector<vespalib::BFloat16> bf16v;
+ bf16v.emplace_back(1.0);
+ bf16v.emplace_back(1.0);
+ bf16v.emplace_back(1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p0)), 3.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p1)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p2)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p3)), 2.0);
+ EXPECT_FLOAT_EQ(computeEuclideanChecked(t(bf16v), t(p4)), 0.5857863);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p5)), 6.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p6)), 2.0);
}
TEST(DistanceFunctionsTest, euclidean_int8_smoketest)
@@ -81,14 +146,13 @@ TEST(DistanceFunctionsTest, euclidean_int8_smoketest)
std::vector<Int8Float> p5{0.0,-1.0, 0.0};
std::vector<Int8Float> p7{-1.0, 2.0, -2.0};
- EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p0), t(p1)));
- EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p0), t(p5)));
- EXPECT_DOUBLE_EQ(9.0, euclid->calc(t(p0), t(p7)));
-
- EXPECT_DOUBLE_EQ(2.0, euclid->calc(t(p1), t(p5)));
- EXPECT_DOUBLE_EQ(12.0, euclid->calc(t(p1), t(p7)));
- EXPECT_DOUBLE_EQ(14.0, euclid->calc(t(p5), t(p7)));
+ EXPECT_DOUBLE_EQ(1.0, computeEuclideanChecked(t(p0), t(p1)));
+ EXPECT_DOUBLE_EQ(1.0, computeEuclideanChecked(t(p0), t(p5)));
+ EXPECT_DOUBLE_EQ(9.0, computeEuclideanChecked(t(p0), t(p7)));
+ EXPECT_DOUBLE_EQ(2.0, computeEuclideanChecked(t(p1), t(p5)));
+ EXPECT_DOUBLE_EQ(12.0, computeEuclideanChecked(t(p1), t(p7)));
+ EXPECT_DOUBLE_EQ(14.0, computeEuclideanChecked(t(p5), t(p7)));
}
double computeAngularChecked(TypedCells a, TypedCells b) {
diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
index 090042e5b83..1783e0da1dd 100644
--- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
@@ -34,6 +34,7 @@ vespa_add_library(searchlib_tensor OBJECT
serialized_tensor_ref.cpp
small_subspaces_buffer_type.cpp
subspace_type.cpp
+ temporary_vector_store.cpp
tensor_attribute.cpp
tensor_attribute_loader.cpp
tensor_attribute_saver.cpp
diff --git a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
index 5101373c047..85eac76728c 100644
--- a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "angular_distance.h"
+#include "temporary_vector_store.h"
using vespalib::typify_invoke;
using vespalib::eval::TypifyCellType;
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
index 56edbf9fede..33b94e5218c 100644
--- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
@@ -1,58 +1,3 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "bound_distance_function.h"
-#include <vespa/log/log.h>
-
-LOG_SETUP(".searchlib.tensor.bound_distance_function");
-
-using vespalib::ConstArrayRef;
-using vespalib::ArrayRef;
-using vespalib::eval::CellType;
-using vespalib::eval::TypedCells;
-
-namespace search::tensor {
-
-namespace {
-
-template<typename FromType, typename ToType>
-ConstArrayRef<ToType>
-convert_cells(ArrayRef<ToType> space, TypedCells cells)
-{
- assert(cells.size == space.size());
- auto old_cells = cells.typify<FromType>();
- ToType *p = space.data();
- for (FromType value : old_cells) {
- ToType conv(value);
- *p++ = conv;
- }
- return space;
-}
-
-template <typename ToType>
-struct ConvertCellsSelector
-{
- template <typename FromType> static auto invoke(ArrayRef<ToType> dst, TypedCells src) {
- return convert_cells<FromType, ToType>(dst, src);
- }
-};
-
-} // namespace
-
-template <typename FloatType>
-ConstArrayRef<FloatType>
-TemporaryVectorStore<FloatType>::internal_convert(TypedCells cells, size_t offset) {
- LOG_ASSERT(cells.size * 2 == _tmpSpace.size());
- ArrayRef<FloatType> where(_tmpSpace.data() + offset, cells.size);
- using MyTypify = vespalib::eval::TypifyCellType;
- using MySelector = ConvertCellsSelector<FloatType>;
- ConstArrayRef<FloatType> result = vespalib::typify_invoke<1,MyTypify,MySelector>(cells.type, where, cells);
- return result;
-}
-
-template class TemporaryVectorStore<float>;
-template class TemporaryVectorStore<double>;
-
-template class ConvertingBoundDistance<float>;
-template class ConvertingBoundDistance<double>;
-
-}
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
index 0f51e8a33ef..5d602a52227 100644
--- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
@@ -5,7 +5,6 @@
#include <memory>
#include <vespa/eval/eval/cell_type.h>
#include <vespa/eval/eval/typed_cells.h>
-#include <vespa/vespalib/util/array.h>
#include <vespa/vespalib/util/arrayref.h>
#include "distance_function.h"
@@ -43,51 +42,4 @@ public:
double limit) const = 0;
};
-
-/** helper class - temporary storage of possibly-converted vector cells */
-template <typename FloatType>
-class TemporaryVectorStore {
-private:
- vespalib::Array<FloatType> _tmpSpace;
- vespalib::ConstArrayRef<FloatType> internal_convert(vespalib::eval::TypedCells cells, size_t offset);
-public:
- TemporaryVectorStore(size_t vectorSize) : _tmpSpace(vectorSize * 2) {}
- vespalib::ConstArrayRef<FloatType> storeLhs(vespalib::eval::TypedCells cells) {
- return internal_convert(cells, 0);
- }
- vespalib::ConstArrayRef<FloatType> convertRhs(vespalib::eval::TypedCells cells) {
- if (vespalib::eval::get_cell_type<FloatType>() == cells.type) [[likely]] {
- return cells.unsafe_typify<FloatType>();
- } else {
- return internal_convert(cells, cells.size);
- }
- }
-};
-
-template<typename FloatType>
-class ConvertingBoundDistance : public BoundDistanceFunction {
- mutable TemporaryVectorStore<FloatType> _tmpSpace;
- const vespalib::eval::TypedCells _lhs;
- const DistanceFunction &_df;
-public:
- ConvertingBoundDistance(const vespalib::eval::TypedCells& lhs, const DistanceFunction &df)
- : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()),
- _tmpSpace(lhs.size),
- _lhs(_tmpSpace.storeLhs(lhs)),
- _df(df)
- {}
- double calc(const vespalib::eval::TypedCells& rhs) const override {
- return _df.calc(_lhs, vespalib::eval::TypedCells(_tmpSpace.convertRhs(rhs)));
- }
- double convert_threshold(double threshold) const override {
- return _df.convert_threshold(threshold);
- }
- double to_rawscore(double distance) const override {
- return _df.to_rawscore(distance);
- }
- double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override {
- return _df.calc_with_limit(_lhs, vespalib::eval::TypedCells(_tmpSpace.convertRhs(rhs)), limit);
- }
-};
-
}
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
index 7ccca655943..cca492ef212 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
@@ -100,6 +100,13 @@ make_distance_function_factory(search::attribute::DistanceMetric variant,
}
return std::make_unique<AngularDistanceFunctionFactory<float>>();
}
+ if (variant == DistanceMetric::Euclidean) {
+ switch (cell_type) {
+ case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>();
+ case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>();
+ default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>();
+ }
+ }
auto df = make_distance_function(variant, cell_type);
return std::make_unique<SimpleDistanceFunctionFactory>(std::move(df));
}
diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
index c83f1821321..92d4e7af406 100644
--- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "euclidean_distance.h"
+#include "temporary_vector_store.h"
using vespalib::typify_invoke;
using vespalib::eval::TypifyCellType;
@@ -48,4 +49,73 @@ SquaredEuclideanDistance::calc_with_limit(const vespalib::eval::TypedCells& lhs,
template class SquaredEuclideanDistanceHW<float>;
template class SquaredEuclideanDistanceHW<double>;
+using vespalib::eval::Int8Float;
+
+template<typename FloatType>
+class BoundEuclideanDistance : public BoundDistanceFunction {
+private:
+ const vespalib::hwaccelrated::IAccelrated & _computer;
+ mutable TemporaryVectorStore<FloatType> _tmpSpace;
+ const vespalib::ConstArrayRef<FloatType> _lhs_vector;
+ static const double *cast(const double * p) { return p; }
+ static const float *cast(const float * p) { return p; }
+ static const int8_t *cast(const Int8Float * p) { return reinterpret_cast<const int8_t *>(p); }
+public:
+ BoundEuclideanDistance(const vespalib::eval::TypedCells& lhs)
+ : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()),
+ _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
+ _tmpSpace(lhs.size),
+ _lhs_vector(_tmpSpace.storeLhs(lhs))
+ {}
+ double calc(const vespalib::eval::TypedCells& rhs) const override {
+ size_t sz = _lhs_vector.size();
+ vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
+ assert(sz == rhs_vector.size());
+ auto a = &_lhs_vector[0];
+ auto b = &rhs_vector[0];
+ return _computer.squaredEuclideanDistance(cast(a), cast(b), sz);
+ }
+ double convert_threshold(double threshold) const override {
+ return threshold*threshold;
+ }
+ double to_rawscore(double distance) const override {
+ double d = sqrt(distance);
+ double score = 1.0 / (1.0 + d);
+ return score;
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override {
+ vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
+ double sum = 0.0;
+ size_t sz = _lhs_vector.size();
+ assert(sz == rhs_vector.size());
+ for (size_t i = 0; i < sz && sum <= limit; ++i) {
+ double diff = _lhs_vector[i] - rhs_vector[i];
+ sum += diff*diff;
+ }
+ return sum;
+ }
+};
+
+template class BoundEuclideanDistance<Int8Float>;
+template class BoundEuclideanDistance<float>;
+template class BoundEuclideanDistance<double>;
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+EuclideanDistanceFunctionFactory<FloatType>::for_query_vector(const vespalib::eval::TypedCells& lhs) {
+ using DFT = BoundEuclideanDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+EuclideanDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib::eval::TypedCells& lhs) {
+ using DFT = BoundEuclideanDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template class EuclideanDistanceFunctionFactory<Int8Float>;
+template class EuclideanDistanceFunctionFactory<float>;
+template class EuclideanDistanceFunctionFactory<double>;
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
index 6505ea119ea..b406f0d3d1a 100644
--- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
+++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
@@ -3,6 +3,7 @@
#pragma once
#include "distance_function.h"
+#include "distance_function_factory.h"
#include <vespa/eval/eval/typed_cells.h>
#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
#include <cmath>
@@ -78,4 +79,15 @@ private:
const vespalib::hwaccelrated::IAccelrated & _computer;
};
+
+template <typename FloatType>
+class EuclideanDistanceFunctionFactory : public DistanceFunctionFactory {
+public:
+ EuclideanDistanceFunctionFactory()
+ : DistanceFunctionFactory(vespalib::eval::get_cell_type<FloatType>())
+ {}
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override;
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override;
+};
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp
new file mode 100644
index 00000000000..cc45f857d9f
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp
@@ -0,0 +1,57 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "temporary_vector_store.h"
+
+#include <vespa/log/log.h>
+
+LOG_SETUP(".searchlib.tensor.temporary_vector_store");
+
+using vespalib::ConstArrayRef;
+using vespalib::ArrayRef;
+using vespalib::eval::CellType;
+using vespalib::eval::TypedCells;
+
+namespace search::tensor {
+
+namespace {
+
+template<typename FromType, typename ToType>
+ConstArrayRef<ToType>
+convert_cells(ArrayRef<ToType> space, TypedCells cells)
+{
+ assert(cells.size == space.size());
+ auto old_cells = cells.typify<FromType>();
+ ToType *p = space.data();
+ for (FromType value : old_cells) {
+ ToType conv(value);
+ *p++ = conv;
+ }
+ return space;
+}
+
+template <typename ToType>
+struct ConvertCellsSelector
+{
+ template <typename FromType> static auto invoke(ArrayRef<ToType> dst, TypedCells src) {
+ return convert_cells<FromType, ToType>(dst, src);
+ }
+};
+
+} // namespace
+
+template <typename FloatType>
+ConstArrayRef<FloatType>
+TemporaryVectorStore<FloatType>::internal_convert(TypedCells cells, size_t offset) {
+ LOG_ASSERT(cells.size * 2 == _tmpSpace.size());
+ ArrayRef<FloatType> where(_tmpSpace.data() + offset, cells.size);
+ using MyTypify = vespalib::eval::TypifyCellType;
+ using MySelector = ConvertCellsSelector<FloatType>;
+ ConstArrayRef<FloatType> result = vespalib::typify_invoke<1,MyTypify,MySelector>(cells.type, where, cells);
+ return result;
+}
+
+template class TemporaryVectorStore<vespalib::eval::Int8Float>;
+template class TemporaryVectorStore<float>;
+template class TemporaryVectorStore<double>;
+
+}
diff --git a/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.h b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.h
new file mode 100644
index 00000000000..cd816621f91
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.h
@@ -0,0 +1,32 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <memory>
+#include <vespa/eval/eval/cell_type.h>
+#include <vespa/eval/eval/typed_cells.h>
+#include <vespa/vespalib/util/arrayref.h>
+
+namespace search::tensor {
+
+/** helper class - temporary storage of possibly-converted vector cells */
+template <typename FloatType>
+class TemporaryVectorStore {
+private:
+ std::vector<FloatType> _tmpSpace;
+ vespalib::ConstArrayRef<FloatType> internal_convert(vespalib::eval::TypedCells cells, size_t offset);
+public:
+ TemporaryVectorStore(size_t vectorSize) : _tmpSpace(vectorSize * 2) {}
+ vespalib::ConstArrayRef<FloatType> storeLhs(vespalib::eval::TypedCells cells) {
+ return internal_convert(cells, 0);
+ }
+ vespalib::ConstArrayRef<FloatType> convertRhs(vespalib::eval::TypedCells cells) {
+ if (vespalib::eval::get_cell_type<FloatType>() == cells.type) [[likely]] {
+ return cells.unsafe_typify<FloatType>();
+ } else {
+ return internal_convert(cells, cells.size);
+ }
+ }
+};
+
+}