aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2022-11-25 15:41:00 +0000
committerGeir Storli <geirst@yahooinc.com>2022-11-25 15:41:00 +0000
commit45518057c9d373af2616dbde11e883c05ca96859 (patch)
treec46fcad3aeeaeddb60aa3d9421d05126c393d78d /eval
parent3038c01d054748d4d2422b7f3cfe2b38f664de6c (diff)
Support mixed tensor attribute with 2 dimensions when creating distance calculator.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/value_type/value_type_test.cpp12
-rw-r--r--eval/src/vespa/eval/eval/value_type.cpp11
-rw-r--r--eval/src/vespa/eval/eval/value_type.h1
3 files changed, 24 insertions, 0 deletions
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp
index 245c9c30242..feb6fb5c368 100644
--- a/eval/src/tests/eval/value_type/value_type_test.cpp
+++ b/eval/src/tests/eval/value_type/value_type_test.cpp
@@ -327,6 +327,18 @@ TEST("require that nontrivial indexed dimensions can be obtained") {
TEST_DO(my_check(type("tensor(a[1],b[1],x[10],y{},z[1])").nontrivial_indexed_dimensions()));
}
+TEST("require that indexed dimensions can be obtained") {
+ auto my_check = [](const auto &list, size_t exp_size)
+ {
+ ASSERT_EQUAL(list.size(), 1u);
+ EXPECT_EQUAL(list[0].name, "x");
+ EXPECT_EQUAL(list[0].size, exp_size);
+ };
+ EXPECT_TRUE(type("double").indexed_dimensions().empty());
+ TEST_DO(my_check(type("tensor(x[10],y{})").indexed_dimensions(), 10));
+ TEST_DO(my_check(type("tensor(y{},x[1])").indexed_dimensions(), 1));
+}
+
TEST("require that mapped dimensions can be obtained") {
auto my_check = [](const auto &list)
{
diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp
index dc5ce645a8c..7d088b22e06 100644
--- a/eval/src/vespa/eval/eval/value_type.cpp
+++ b/eval/src/vespa/eval/eval/value_type.cpp
@@ -251,6 +251,17 @@ ValueType::nontrivial_indexed_dimensions() const {
}
std::vector<ValueType::Dimension>
+ValueType::indexed_dimensions() const {
+ std::vector<ValueType::Dimension> result;
+ for (const auto &dim: dimensions()) {
+ if (dim.is_indexed()) {
+ result.push_back(dim);
+ }
+ }
+ return result;
+}
+
+std::vector<ValueType::Dimension>
ValueType::mapped_dimensions() const {
std::vector<ValueType::Dimension> result;
for (const auto &dim: dimensions()) {
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index b7a7c92e137..49f88edb2f9 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -66,6 +66,7 @@ public:
size_t dense_subspace_size() const;
const std::vector<Dimension> &dimensions() const { return _dimensions; }
std::vector<Dimension> nontrivial_indexed_dimensions() const;
+ std::vector<Dimension> indexed_dimensions() const;
std::vector<Dimension> mapped_dimensions() const;
size_t dimension_index(const vespalib::string &name) const;
std::vector<vespalib::string> dimension_names() const;