diff options
author | Tor Egge <Tor.Egge@online.no> | 2023-02-23 13:58:20 +0100 |
---|---|---|
committer | Tor Egge <Tor.Egge@online.no> | 2023-02-23 13:58:20 +0100 |
commit | 4d09377220c1c2d450b33455380cbcc00fd8a890 (patch) | |
tree | f8b9a3a46c96c9ff4c0174fbcd337368e782e34c /searchlib | |
parent | f66f816102ce0a7c3aaba72d1db61a83157259ed (diff) |
Extend distance calculator with member function that calculates closest subspace.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp | 13 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/distance_calculator.h | 14 |
2 files changed, 27 insertions, 0 deletions
diff --git a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp index 5e556979254..ef4292ddbb4 100644 --- a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp +++ b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp @@ -18,6 +18,8 @@ using namespace vespalib::eval; using search::AttributeVector; +using OptSubspace = std::optional<uint32_t>; + std::unique_ptr<Value> make_tensor(const vespalib::string& expr) { return SimpleValue::from_spec(TensorSpec::from_expr(expr)); } @@ -49,6 +51,11 @@ public: auto calc = DistanceCalculator::make_with_validation(*attr, *qt); return calc->calc_raw_score(docid); } + OptSubspace calc_closest_subspace(uint32_t docid, const vespalib::string& query_tensor) { + auto qt = make_tensor(query_tensor); + auto calc = DistanceCalculator::make_with_validation(*attr, *qt); + return calc->calc_closest_subspace(attr->asTensorAttribute()->get_vectors(docid)); + } void make_calc_throws(const vespalib::string& query_tensor) { auto qt = make_tensor(query_tensor); DistanceCalculator::make_with_validation(*attr, *qt); @@ -63,9 +70,11 @@ TEST_F(DistanceCalculatorTest, calculation_over_dense_tensor_attribute) vespalib::string qt = "tensor(y[2]):[7,10]"; EXPECT_DOUBLE_EQ(16, calc_distance(1, qt)); EXPECT_DOUBLE_EQ(max_distance, calc_distance(2, qt)); + EXPECT_EQ(OptSubspace(0), calc_closest_subspace(1, qt)); EXPECT_DOUBLE_EQ(1.0/(1.0 + 4.0), calc_rawscore(1, qt)); EXPECT_DOUBLE_EQ(0.0, calc_rawscore(2, qt)); + EXPECT_EQ(OptSubspace(), calc_closest_subspace(2, qt)); } TEST_F(DistanceCalculatorTest, calculation_over_mixed_tensor_attribute) @@ -77,8 +86,12 @@ TEST_F(DistanceCalculatorTest, calculation_over_mixed_tensor_attribute) vespalib::string qt_2 = "tensor(y[2]):[1,10]"; EXPECT_DOUBLE_EQ(16, calc_distance(1, qt_1)); EXPECT_DOUBLE_EQ(4, calc_distance(1, qt_2)); + EXPECT_EQ(OptSubspace(1), calc_closest_subspace(1, qt_1)); + EXPECT_EQ(OptSubspace(0), calc_closest_subspace(1, qt_2)); EXPECT_DOUBLE_EQ(max_distance, calc_distance(2, qt_1)); EXPECT_DOUBLE_EQ(max_distance, calc_distance(3, qt_1)); + EXPECT_EQ(OptSubspace(), calc_closest_subspace(2, qt_1)); + EXPECT_EQ(OptSubspace(), calc_closest_subspace(3, qt_1)); EXPECT_DOUBLE_EQ(1.0/(1.0 + 4.0), calc_rawscore(1, qt_1)); EXPECT_DOUBLE_EQ(1.0/(1.0 + 2.0), calc_rawscore(1, qt_2)); diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h index 320f071cbbb..f501b004254 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h +++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h @@ -4,6 +4,7 @@ #include "distance_function.h" #include "i_tensor_attribute.h" #include "vector_bundle.h" +#include <optional> namespace vespalib::eval { struct Value; } @@ -64,6 +65,19 @@ public: return result; } + std::optional<uint32_t> calc_closest_subspace(VectorBundle vectors) { + double best_distance = 0.0; + std::optional<uint32_t> closest_subspace; + for (uint32_t i = 0; i < vectors.subspaces(); ++i) { + double distance = _dist_fun->calc(_query_tensor_cells, vectors.cells(i)); + if (!closest_subspace.has_value() || distance < best_distance) { + best_distance = distance; + closest_subspace = i; + } + } + return closest_subspace; + } + /** * Create a calculator for the given attribute tensor and query tensor, if possible. * |