summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@online.no>2023-02-23 13:58:20 +0100
committerTor Egge <Tor.Egge@online.no>2023-02-23 13:58:20 +0100
commit4d09377220c1c2d450b33455380cbcc00fd8a890 (patch)
treef8b9a3a46c96c9ff4c0174fbcd337368e782e34c
parentf66f816102ce0a7c3aaba72d1db61a83157259ed (diff)
Extend distance calculator with member function that calculates closest subspace.
-rw-r--r--searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp13
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_calculator.h14
2 files changed, 27 insertions, 0 deletions
diff --git a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp
index 5e556979254..ef4292ddbb4 100644
--- a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp
+++ b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp
@@ -18,6 +18,8 @@ using namespace vespalib::eval;
using search::AttributeVector;
+using OptSubspace = std::optional<uint32_t>;
+
std::unique_ptr<Value> make_tensor(const vespalib::string& expr) {
return SimpleValue::from_spec(TensorSpec::from_expr(expr));
}
@@ -49,6 +51,11 @@ public:
auto calc = DistanceCalculator::make_with_validation(*attr, *qt);
return calc->calc_raw_score(docid);
}
+ OptSubspace calc_closest_subspace(uint32_t docid, const vespalib::string& query_tensor) {
+ auto qt = make_tensor(query_tensor);
+ auto calc = DistanceCalculator::make_with_validation(*attr, *qt);
+ return calc->calc_closest_subspace(attr->asTensorAttribute()->get_vectors(docid));
+ }
void make_calc_throws(const vespalib::string& query_tensor) {
auto qt = make_tensor(query_tensor);
DistanceCalculator::make_with_validation(*attr, *qt);
@@ -63,9 +70,11 @@ TEST_F(DistanceCalculatorTest, calculation_over_dense_tensor_attribute)
vespalib::string qt = "tensor(y[2]):[7,10]";
EXPECT_DOUBLE_EQ(16, calc_distance(1, qt));
EXPECT_DOUBLE_EQ(max_distance, calc_distance(2, qt));
+ EXPECT_EQ(OptSubspace(0), calc_closest_subspace(1, qt));
EXPECT_DOUBLE_EQ(1.0/(1.0 + 4.0), calc_rawscore(1, qt));
EXPECT_DOUBLE_EQ(0.0, calc_rawscore(2, qt));
+ EXPECT_EQ(OptSubspace(), calc_closest_subspace(2, qt));
}
TEST_F(DistanceCalculatorTest, calculation_over_mixed_tensor_attribute)
@@ -77,8 +86,12 @@ TEST_F(DistanceCalculatorTest, calculation_over_mixed_tensor_attribute)
vespalib::string qt_2 = "tensor(y[2]):[1,10]";
EXPECT_DOUBLE_EQ(16, calc_distance(1, qt_1));
EXPECT_DOUBLE_EQ(4, calc_distance(1, qt_2));
+ EXPECT_EQ(OptSubspace(1), calc_closest_subspace(1, qt_1));
+ EXPECT_EQ(OptSubspace(0), calc_closest_subspace(1, qt_2));
EXPECT_DOUBLE_EQ(max_distance, calc_distance(2, qt_1));
EXPECT_DOUBLE_EQ(max_distance, calc_distance(3, qt_1));
+ EXPECT_EQ(OptSubspace(), calc_closest_subspace(2, qt_1));
+ EXPECT_EQ(OptSubspace(), calc_closest_subspace(3, qt_1));
EXPECT_DOUBLE_EQ(1.0/(1.0 + 4.0), calc_rawscore(1, qt_1));
EXPECT_DOUBLE_EQ(1.0/(1.0 + 2.0), calc_rawscore(1, qt_2));
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
index 320f071cbbb..f501b004254 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
+++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
@@ -4,6 +4,7 @@
#include "distance_function.h"
#include "i_tensor_attribute.h"
#include "vector_bundle.h"
+#include <optional>
namespace vespalib::eval { struct Value; }
@@ -64,6 +65,19 @@ public:
return result;
}
+ std::optional<uint32_t> calc_closest_subspace(VectorBundle vectors) {
+ double best_distance = 0.0;
+ std::optional<uint32_t> closest_subspace;
+ for (uint32_t i = 0; i < vectors.subspaces(); ++i) {
+ double distance = _dist_fun->calc(_query_tensor_cells, vectors.cells(i));
+ if (!closest_subspace.has_value() || distance < best_distance) {
+ best_distance = distance;
+ closest_subspace = i;
+ }
+ }
+ return closest_subspace;
+ }
+
/**
* Create a calculator for the given attribute tensor and query tensor, if possible.
*