summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--eval/src/tests/eval/value_type/value_type_test.cpp12
-rw-r--r--eval/src/vespa/eval/eval/value_type.cpp11
-rw-r--r--eval/src/vespa/eval/eval/value_type.h1
-rw-r--r--searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp45
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp32
5 files changed, 75 insertions, 26 deletions
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp
index 245c9c30242..feb6fb5c368 100644
--- a/eval/src/tests/eval/value_type/value_type_test.cpp
+++ b/eval/src/tests/eval/value_type/value_type_test.cpp
@@ -327,6 +327,18 @@ TEST("require that nontrivial indexed dimensions can be obtained") {
TEST_DO(my_check(type("tensor(a[1],b[1],x[10],y{},z[1])").nontrivial_indexed_dimensions()));
}
+TEST("require that indexed dimensions can be obtained") {
+ auto my_check = [](const auto &list, size_t exp_size)
+ {
+ ASSERT_EQUAL(list.size(), 1u);
+ EXPECT_EQUAL(list[0].name, "x");
+ EXPECT_EQUAL(list[0].size, exp_size);
+ };
+ EXPECT_TRUE(type("double").indexed_dimensions().empty());
+ TEST_DO(my_check(type("tensor(x[10],y{})").indexed_dimensions(), 10));
+ TEST_DO(my_check(type("tensor(y{},x[1])").indexed_dimensions(), 1));
+}
+
TEST("require that mapped dimensions can be obtained") {
auto my_check = [](const auto &list)
{
diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp
index dc5ce645a8c..7d088b22e06 100644
--- a/eval/src/vespa/eval/eval/value_type.cpp
+++ b/eval/src/vespa/eval/eval/value_type.cpp
@@ -251,6 +251,17 @@ ValueType::nontrivial_indexed_dimensions() const {
}
std::vector<ValueType::Dimension>
+ValueType::indexed_dimensions() const {
+ std::vector<ValueType::Dimension> result;
+ for (const auto &dim: dimensions()) {
+ if (dim.is_indexed()) {
+ result.push_back(dim);
+ }
+ }
+ return result;
+}
+
+std::vector<ValueType::Dimension>
ValueType::mapped_dimensions() const {
std::vector<ValueType::Dimension> result;
for (const auto &dim: dimensions()) {
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index b7a7c92e137..49f88edb2f9 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -66,6 +66,7 @@ public:
size_t dense_subspace_size() const;
const std::vector<Dimension> &dimensions() const { return _dimensions; }
std::vector<Dimension> nontrivial_indexed_dimensions() const;
+ std::vector<Dimension> indexed_dimensions() const;
std::vector<Dimension> mapped_dimensions() const;
size_t dimension_index(const vespalib::string &name) const;
std::vector<vespalib::string> dimension_names() const;
diff --git a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp
index 11f767b546d..5e556979254 100644
--- a/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp
+++ b/searchlib/src/tests/tensor/distance_calculator/distance_calculator_test.cpp
@@ -6,9 +6,9 @@
#include <vespa/searchcommon/attribute/config.h>
#include <vespa/searchlib/attribute/attributevector.h>
#include <vespa/searchlib/tensor/distance_calculator.h>
-#include <vespa/searchlib/tensor/distance_function_factory.h>
#include <vespa/searchlib/test/attribute_builder.h>
#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/vespalib/util/exceptions.h>
#include <iostream>
using namespace search::attribute::test;
@@ -18,16 +18,16 @@ using namespace vespalib::eval;
using search::AttributeVector;
+std::unique_ptr<Value> make_tensor(const vespalib::string& expr) {
+ return SimpleValue::from_spec(TensorSpec::from_expr(expr));
+}
+
class DistanceCalculatorTest : public testing::Test {
public:
std::shared_ptr<AttributeVector> attr;
- const ITensorAttribute* attr_tensor;
- std::unique_ptr<DistanceFunction> func;
DistanceCalculatorTest()
- : attr(),
- attr_tensor(),
- func(make_distance_function(DistanceMetric::Euclidean, CellType::DOUBLE))
+ : attr()
{
}
@@ -35,19 +35,23 @@ public:
const std::vector<vespalib::string>& tensor_values) {
Config cfg(BasicType::TENSOR);
cfg.setTensorType(ValueType::from_spec(tensor_type));
+ cfg.set_distance_metric(DistanceMetric::Euclidean);
attr = AttributeBuilder("doc_tensor", cfg).fill_tensor(tensor_values).get();
- attr_tensor = dynamic_cast<const ITensorAttribute*>(attr.get());
- ASSERT_TRUE(attr_tensor != nullptr);
+ ASSERT_TRUE(attr.get() != nullptr);
}
double calc_distance(uint32_t docid, const vespalib::string& query_tensor) {
- auto qv = SimpleValue::from_spec(TensorSpec::from_expr(query_tensor));
- DistanceCalculator calc(*attr_tensor, *qv, *func);
- return calc.calc_with_limit(docid, std::numeric_limits<double>::max());
+ auto qt = make_tensor(query_tensor);
+ auto calc = DistanceCalculator::make_with_validation(*attr, *qt);
+ return calc->calc_with_limit(docid, std::numeric_limits<double>::max());
}
double calc_rawscore(uint32_t docid, const vespalib::string& query_tensor) {
- auto qv = SimpleValue::from_spec(TensorSpec::from_expr(query_tensor));
- DistanceCalculator calc(*attr_tensor, *qv, *func);
- return calc.calc_raw_score(docid);
+ auto qt = make_tensor(query_tensor);
+ auto calc = DistanceCalculator::make_with_validation(*attr, *qt);
+ return calc->calc_raw_score(docid);
+ }
+ void make_calc_throws(const vespalib::string& query_tensor) {
+ auto qt = make_tensor(query_tensor);
+ DistanceCalculator::make_with_validation(*attr, *qt);
}
};
@@ -82,5 +86,18 @@ TEST_F(DistanceCalculatorTest, calculation_over_mixed_tensor_attribute)
EXPECT_DOUBLE_EQ(0.0, calc_rawscore(3, qt_1));
}
+TEST_F(DistanceCalculatorTest, make_calculator_for_unsupported_types_throws)
+{
+ build_attribute("tensor(x{},y{})", {});
+ EXPECT_THROW(make_calc_throws("tensor(y[2]):[9,10]"), vespalib::IllegalArgumentException);
+
+ build_attribute("tensor(x{},y{},z[2])", {});
+ EXPECT_THROW(make_calc_throws("tensor(z[2]):[9,10]"), vespalib::IllegalArgumentException);
+
+ build_attribute("tensor(x{},y[2])", {});
+ EXPECT_THROW(make_calc_throws("tensor(y{}):{{y:\"a\"}:9,{y:\"b\"}:10}"), vespalib::IllegalArgumentException);
+ EXPECT_THROW(make_calc_throws("tensor(y[3]):[9,10]"), vespalib::IllegalArgumentException);
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
index d6d5433ff15..b669b5ffea6 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
@@ -47,12 +47,7 @@ struct ConvertCellsSelector
}
};
-bool
-is_compatible(const vespalib::eval::ValueType& lhs,
- const vespalib::eval::ValueType& rhs)
-{
- return (lhs.dimensions() == rhs.dimensions());
-}
+
}
@@ -100,7 +95,24 @@ DistanceCalculator::~DistanceCalculator() = default;
namespace {
+bool
+supported_tensor_type(const vespalib::eval::ValueType& type)
+{
+ if (type.is_dense() && type.dimensions().size() == 1) {
+ return true;
+ }
+ if (type.is_mixed() && type.dimensions().size() == 2) {
+ return true;
+ }
+ return false;
+}
+bool
+is_compatible(const vespalib::eval::ValueType& lhs,
+ const vespalib::eval::ValueType& rhs)
+{
+ return (lhs.indexed_dimensions() == rhs.indexed_dimensions());
+}
}
@@ -113,8 +125,8 @@ DistanceCalculator::make_with_validation(const search::attribute::IAttributeVect
throw IllegalArgumentException("Attribute is not a tensor");
}
const auto& at_type = attr_tensor->getTensorType();
- if ((!at_type.is_dense()) || (at_type.dimensions().size() != 1)) {
- throw IllegalArgumentException(make_string("Attribute tensor type (%s) is not a dense tensor of order 1",
+ if (!supported_tensor_type(at_type)) {
+ throw IllegalArgumentException(make_string("Attribute tensor type (%s) is not supported",
at_type.to_spec().c_str()));
}
const auto& qt_type = query_tensor_in.type();
@@ -126,10 +138,6 @@ DistanceCalculator::make_with_validation(const search::attribute::IAttributeVect
throw IllegalArgumentException(make_string("Attribute tensor type (%s) and query tensor type (%s) are not compatible",
at_type.to_spec().c_str(), qt_type.to_spec().c_str()));
}
- if (!attr_tensor->supports_extract_cells_ref()) {
- throw IllegalArgumentException(make_string("Attribute tensor does not support access to tensor data (type=%s)",
- at_type.to_spec().c_str()));
- }
return std::make_unique<DistanceCalculator>(*attr_tensor, query_tensor_in);
}