diff options
Diffstat (limited to 'searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp')
-rw-r--r-- | searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp index 5e556979254..ef4292ddbb4 100644 --- a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp +++ b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp @@ -18,6 +18,8 @@ using namespace vespalib::eval; using search::AttributeVector; +using OptSubspace = std::optional<uint32_t>; + std::unique_ptr<Value> make_tensor(const vespalib::string& expr) { return SimpleValue::from_spec(TensorSpec::from_expr(expr)); } @@ -49,6 +51,11 @@ public: auto calc = DistanceCalculator::make_with_validation(*attr, *qt); return calc->calc_raw_score(docid); } + OptSubspace calc_closest_subspace(uint32_t docid, const vespalib::string& query_tensor) { + auto qt = make_tensor(query_tensor); + auto calc = DistanceCalculator::make_with_validation(*attr, *qt); + return calc->calc_closest_subspace(attr->asTensorAttribute()->get_vectors(docid)); + } void make_calc_throws(const vespalib::string& query_tensor) { auto qt = make_tensor(query_tensor); DistanceCalculator::make_with_validation(*attr, *qt); @@ -63,9 +70,11 @@ TEST_F(DistanceCalculatorTest, calculation_over_dense_tensor_attribute) vespalib::string qt = "tensor(y[2]):[7,10]"; EXPECT_DOUBLE_EQ(16, calc_distance(1, qt)); EXPECT_DOUBLE_EQ(max_distance, calc_distance(2, qt)); + EXPECT_EQ(OptSubspace(0), calc_closest_subspace(1, qt)); EXPECT_DOUBLE_EQ(1.0/(1.0 + 4.0), calc_rawscore(1, qt)); EXPECT_DOUBLE_EQ(0.0, calc_rawscore(2, qt)); + EXPECT_EQ(OptSubspace(), calc_closest_subspace(2, qt)); } TEST_F(DistanceCalculatorTest, calculation_over_mixed_tensor_attribute) @@ -77,8 +86,12 @@ TEST_F(DistanceCalculatorTest, calculation_over_mixed_tensor_attribute) vespalib::string qt_2 = "tensor(y[2]):[1,10]"; EXPECT_DOUBLE_EQ(16, calc_distance(1, qt_1)); EXPECT_DOUBLE_EQ(4, calc_distance(1, qt_2)); + EXPECT_EQ(OptSubspace(1), calc_closest_subspace(1, qt_1)); + EXPECT_EQ(OptSubspace(0), calc_closest_subspace(1, qt_2)); EXPECT_DOUBLE_EQ(max_distance, calc_distance(2, qt_1)); EXPECT_DOUBLE_EQ(max_distance, calc_distance(3, qt_1)); + EXPECT_EQ(OptSubspace(), calc_closest_subspace(2, qt_1)); + EXPECT_EQ(OptSubspace(), calc_closest_subspace(3, qt_1)); EXPECT_DOUBLE_EQ(1.0/(1.0 + 4.0), calc_rawscore(1, qt_1)); EXPECT_DOUBLE_EQ(1.0/(1.0 + 2.0), calc_rawscore(1, qt_2)); |