summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Geir Storli <geirst@yahooinc.com> 2023-02-23 15:39:21 +0100
committer GitHub <noreply@github.com> 2023-02-23 15:39:21 +0100
commit 16e96352205ab3ae50393e831be848353db6ce47 (patch)
tree 66553ea1b29c497929b6007a88e8848cd3411709
parent c6f31815922b0bc4444434ec292933d493ece623 (diff)
parent 4d09377220c1c2d450b33455380cbcc00fd8a890 (diff)
Merge pull request #26163 from vespa-engine/toregge/extend-distance-calculator
Extend distance calculator with member function that calculates closest subspace.
-rw-r--r-- searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp 13
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/distance_calculator.h 14
2 files changed, 27 insertions, 0 deletions
diff --git a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp
index 5e556979254..ef4292ddbb4 100644
--- a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp
+++ b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp
@@ -18,6 +18,8 @@ using namespace vespalib::eval;
using search::AttributeVector;
+using OptSubspace = std::optional<uint32_t>;
+
std::unique_ptr<Value> make_tensor(const vespalib::string& expr) {
return SimpleValue::from_spec(TensorSpec::from_expr(expr));
}
@@ -49,6 +51,11 @@ public:
auto calc = DistanceCalculator::make_with_validation(*attr, *qt);
return calc->calc_raw_score(docid);
}
+ OptSubspace calc_closest_subspace(uint32_t docid, const vespalib::string& query_tensor) { // test helper: closest subspace of docid's vectors for the given query tensor expression
+ auto qt = make_tensor(query_tensor); // build the query tensor value from its expression string
+ auto calc = DistanceCalculator::make_with_validation(*attr, *qt); // calculator for this attribute / query tensor pair
+ return calc->calc_closest_subspace(attr->asTensorAttribute()->get_vectors(docid)); // empty optional when get_vectors(docid) has no subspaces
+ }
void make_calc_throws(const vespalib::string& query_tensor) {
auto qt = make_tensor(query_tensor);
DistanceCalculator::make_with_validation(*attr, *qt);
@@ -63,9 +70,11 @@ TEST_F(DistanceCalculatorTest, calculation_over_dense_tensor_attribute)
vespalib::string qt = "tensor(y[2]):[7,10]";
EXPECT_DOUBLE_EQ(16, calc_distance(1, qt));
EXPECT_DOUBLE_EQ(max_distance, calc_distance(2, qt));
+ EXPECT_EQ(OptSubspace(0), calc_closest_subspace(1, qt));
EXPECT_DOUBLE_EQ(1.0/(1.0 + 4.0), calc_rawscore(1, qt));
EXPECT_DOUBLE_EQ(0.0, calc_rawscore(2, qt));
+ EXPECT_EQ(OptSubspace(), calc_closest_subspace(2, qt));
}
TEST_F(DistanceCalculatorTest, calculation_over_mixed_tensor_attribute)
@@ -77,8 +86,12 @@ TEST_F(DistanceCalculatorTest, calculation_over_mixed_tensor_attribute)
vespalib::string qt_2 = "tensor(y[2]):[1,10]";
EXPECT_DOUBLE_EQ(16, calc_distance(1, qt_1));
EXPECT_DOUBLE_EQ(4, calc_distance(1, qt_2));
+ EXPECT_EQ(OptSubspace(1), calc_closest_subspace(1, qt_1));
+ EXPECT_EQ(OptSubspace(0), calc_closest_subspace(1, qt_2));
EXPECT_DOUBLE_EQ(max_distance, calc_distance(2, qt_1));
EXPECT_DOUBLE_EQ(max_distance, calc_distance(3, qt_1));
+ EXPECT_EQ(OptSubspace(), calc_closest_subspace(2, qt_1));
+ EXPECT_EQ(OptSubspace(), calc_closest_subspace(3, qt_1));
EXPECT_DOUBLE_EQ(1.0/(1.0 + 4.0), calc_rawscore(1, qt_1));
EXPECT_DOUBLE_EQ(1.0/(1.0 + 2.0), calc_rawscore(1, qt_2));
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
index 320f071cbbb..f501b004254 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
+++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
@@ -4,6 +4,7 @@
#include "distance_function.h"
#include "i_tensor_attribute.h"
#include "vector_bundle.h"
+#include <optional>
namespace vespalib::eval { struct Value; }
@@ -64,6 +65,19 @@ public:
return result;
}
+ std::optional<uint32_t> calc_closest_subspace(VectorBundle vectors) { // index of the subspace in `vectors` with the smallest distance to the query tensor cells
+ double best_distance = 0.0; // placeholder only; overwritten on the first iteration before it is ever compared
+ std::optional<uint32_t> closest_subspace; // stays empty when `vectors` has no subspaces
+ for (uint32_t i = 0; i < vectors.subspaces(); ++i) {
+ double distance = _dist_fun->calc(_query_tensor_cells, vectors.cells(i));
+ if (!closest_subspace.has_value() || distance < best_distance) { // strict '<': on equal distances the lowest index is kept
+ best_distance = distance;
+ closest_subspace = i;
+ }
+ }
+ return closest_subspace;
+ }
+
/**
* Create a calculator for the given attribute tensor and query tensor, if possible.
*