diff options
author | Geir Storli <geirst@yahooinc.com> | 2022-11-25 15:41:00 +0000 |
---|---|---|
committer | Geir Storli <geirst@yahooinc.com> | 2022-11-25 15:41:00 +0000 |
commit | 45518057c9d373af2616dbde11e883c05ca96859 (patch) | |
tree | c46fcad3aeeaeddb60aa3d9421d05126c393d78d | |
parent | 3038c01d054748d4d2422b7f3cfe2b38f664de6c (diff) |
Support mixed tensor attribute with 2 dimensions when creating distance calculator.
5 files changed, 75 insertions, 26 deletions
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp index 245c9c30242..feb6fb5c368 100644 --- a/eval/src/tests/eval/value_type/value_type_test.cpp +++ b/eval/src/tests/eval/value_type/value_type_test.cpp @@ -327,6 +327,18 @@ TEST("require that nontrivial indexed dimensions can be obtained") { TEST_DO(my_check(type("tensor(a[1],b[1],x[10],y{},z[1])").nontrivial_indexed_dimensions())); } +TEST("require that indexed dimensions can be obtained") { + auto my_check = [](const auto &list, size_t exp_size) + { + ASSERT_EQUAL(list.size(), 1u); + EXPECT_EQUAL(list[0].name, "x"); + EXPECT_EQUAL(list[0].size, exp_size); + }; + EXPECT_TRUE(type("double").indexed_dimensions().empty()); + TEST_DO(my_check(type("tensor(x[10],y{})").indexed_dimensions(), 10)); + TEST_DO(my_check(type("tensor(y{},x[1])").indexed_dimensions(), 1)); +} + TEST("require that mapped dimensions can be obtained") { auto my_check = [](const auto &list) { diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp index dc5ce645a8c..7d088b22e06 100644 --- a/eval/src/vespa/eval/eval/value_type.cpp +++ b/eval/src/vespa/eval/eval/value_type.cpp @@ -251,6 +251,17 @@ ValueType::nontrivial_indexed_dimensions() const { } std::vector<ValueType::Dimension> +ValueType::indexed_dimensions() const { + std::vector<ValueType::Dimension> result; + for (const auto &dim: dimensions()) { + if (dim.is_indexed()) { + result.push_back(dim); + } + } + return result; +} + +std::vector<ValueType::Dimension> ValueType::mapped_dimensions() const { std::vector<ValueType::Dimension> result; for (const auto &dim: dimensions()) { diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h index b7a7c92e137..49f88edb2f9 100644 --- a/eval/src/vespa/eval/eval/value_type.h +++ b/eval/src/vespa/eval/eval/value_type.h @@ -66,6 +66,7 @@ public: size_t dense_subspace_size() const; const std::vector<Dimension> &dimensions() const { return _dimensions; } std::vector<Dimension> nontrivial_indexed_dimensions() const; + std::vector<Dimension> indexed_dimensions() const; std::vector<Dimension> mapped_dimensions() const; size_t dimension_index(const vespalib::string &name) const; std::vector<vespalib::string> dimension_names() const; diff --git a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp index 11f767b546d..5e556979254 100644 --- a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp +++ b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp @@ -6,9 +6,9 @@ #include <vespa/searchcommon/attribute/config.h> #include <vespa/searchlib/attribute/attributevector.h> #include <vespa/searchlib/tensor/distance_calculator.h> -#include <vespa/searchlib/tensor/distance_function_factory.h> #include <vespa/searchlib/test/attribute_builder.h> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/exceptions.h> #include <iostream> using namespace search::attribute::test; @@ -18,16 +18,16 @@ using namespace vespalib::eval; using search::AttributeVector; +std::unique_ptr<Value> make_tensor(const vespalib::string& expr) { + return SimpleValue::from_spec(TensorSpec::from_expr(expr)); +} + class DistanceCalculatorTest : public testing::Test { public: std::shared_ptr<AttributeVector> attr; - const ITensorAttribute* attr_tensor; - std::unique_ptr<DistanceFunction> func; DistanceCalculatorTest() - : attr(), - attr_tensor(), - func(make_distance_function(DistanceMetric::Euclidean, CellType::DOUBLE)) + : attr() { } @@ -35,19 +35,23 @@ public: const std::vector<vespalib::string>& tensor_values) { Config cfg(BasicType::TENSOR); cfg.setTensorType(ValueType::from_spec(tensor_type)); + cfg.set_distance_metric(DistanceMetric::Euclidean); attr = AttributeBuilder("doc_tensor", cfg).fill_tensor(tensor_values).get(); - attr_tensor = dynamic_cast<const ITensorAttribute*>(attr.get()); - ASSERT_TRUE(attr_tensor != nullptr); + ASSERT_TRUE(attr.get() != nullptr); } double calc_distance(uint32_t docid, const vespalib::string& query_tensor) { - auto qv = SimpleValue::from_spec(TensorSpec::from_expr(query_tensor)); - DistanceCalculator calc(*attr_tensor, *qv, *func); - return calc.calc_with_limit(docid, std::numeric_limits<double>::max()); + auto qt = make_tensor(query_tensor); + auto calc = DistanceCalculator::make_with_validation(*attr, *qt); + return calc->calc_with_limit(docid, std::numeric_limits<double>::max()); } double calc_rawscore(uint32_t docid, const vespalib::string& query_tensor) { - auto qv = SimpleValue::from_spec(TensorSpec::from_expr(query_tensor)); - DistanceCalculator calc(*attr_tensor, *qv, *func); - return calc.calc_raw_score(docid); + auto qt = make_tensor(query_tensor); + auto calc = DistanceCalculator::make_with_validation(*attr, *qt); + return calc->calc_raw_score(docid); + } + void make_calc_throws(const vespalib::string& query_tensor) { + auto qt = make_tensor(query_tensor); + DistanceCalculator::make_with_validation(*attr, *qt); } }; @@ -82,5 +86,18 @@ TEST_F(DistanceCalculatorTest, calculation_over_mixed_tensor_attribute) EXPECT_DOUBLE_EQ(0.0, calc_rawscore(3, qt_1)); } +TEST_F(DistanceCalculatorTest, make_calculator_for_unsupported_types_throws) +{ + build_attribute("tensor(x{},y{})", {}); + EXPECT_THROW(make_calc_throws("tensor(y[2]):[9,10]"), vespalib::IllegalArgumentException); + + build_attribute("tensor(x{},y{},z[2])", {}); + EXPECT_THROW(make_calc_throws("tensor(z[2]):[9,10]"), vespalib::IllegalArgumentException); + + build_attribute("tensor(x{},y[2])", {}); + EXPECT_THROW(make_calc_throws("tensor(y{}):{{y:\"a\"}:9,{y:\"b\"}:10}"), vespalib::IllegalArgumentException); + EXPECT_THROW(make_calc_throws("tensor(y[3]):[9,10]"), vespalib::IllegalArgumentException); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp index d6d5433ff15..b669b5ffea6 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp @@ -47,12 +47,7 @@ struct ConvertCellsSelector } }; -bool -is_compatible(const vespalib::eval::ValueType& lhs, - const vespalib::eval::ValueType& rhs) -{ - return (lhs.dimensions() == rhs.dimensions()); -} + } @@ -100,7 +95,24 @@ DistanceCalculator::~DistanceCalculator() = default; namespace { +bool +supported_tensor_type(const vespalib::eval::ValueType& type) +{ + if (type.is_dense() && type.dimensions().size() == 1) { + return true; + } + if (type.is_mixed() && type.dimensions().size() == 2) { + return true; + } + return false; +} +bool +is_compatible(const vespalib::eval::ValueType& lhs, + const vespalib::eval::ValueType& rhs) +{ + return (lhs.indexed_dimensions() == rhs.indexed_dimensions()); +} } @@ -113,8 +125,8 @@ DistanceCalculator::make_with_validation(const search::attribute::IAttributeVect throw IllegalArgumentException("Attribute is not a tensor"); } const auto& at_type = attr_tensor->getTensorType(); - if ((!at_type.is_dense()) || (at_type.dimensions().size() != 1)) { - throw IllegalArgumentException(make_string("Attribute tensor type (%s) is not a dense tensor of order 1", + if (!supported_tensor_type(at_type)) { + throw IllegalArgumentException(make_string("Attribute tensor type (%s) is not supported", at_type.to_spec().c_str())); } const auto& qt_type = query_tensor_in.type(); @@ -126,10 +138,6 @@ DistanceCalculator::make_with_validation(const search::attribute::IAttributeVect throw IllegalArgumentException(make_string("Attribute tensor type (%s) and query tensor type (%s) are not compatible", at_type.to_spec().c_str(), qt_type.to_spec().c_str())); } - if (!attr_tensor->supports_extract_cells_ref()) { - throw IllegalArgumentException(make_string("Attribute tensor does not support access to tensor data (type=%s)", - at_type.to_spec().c_str())); - } return std::make_unique<DistanceCalculator>(*attr_tensor, query_tensor_in); } |