diff options
author | Tor Egge <Tor.Egge@broadpark.no> | 2019-02-15 14:58:33 +0100 |
---|---|---|
committer | Tor Egge <Tor.Egge@broadpark.no> | 2019-02-15 15:29:04 +0100 |
commit | 8cb47edf86487817d3391ecfa63bf6d0707b3c55 (patch) | |
tree | 076b9a2dc197a560d194297e3af589ca78a35906 /document | |
parent | 5c7396d271bdd02951854c4f2b7b14de8deeb5ad (diff) |
Add check for tensor type being compatible with a tensor field value
using the current tensor data type.
Diffstat (limited to 'document')
-rw-r--r-- | document/src/tests/datatype/datatype_test.cpp | 43 | ||||
-rw-r--r-- | document/src/vespa/document/datatype/tensor_data_type.cpp | 22 | ||||
-rw-r--r-- | document/src/vespa/document/datatype/tensor_data_type.h | 1 |
3 files changed, 66 insertions, 0 deletions
diff --git a/document/src/tests/datatype/datatype_test.cpp b/document/src/tests/datatype/datatype_test.cpp index 61d44fcfd5e..ef41ee770a2 100644 --- a/document/src/tests/datatype/datatype_test.cpp +++ b/document/src/tests/datatype/datatype_test.cpp @@ -4,7 +4,9 @@ #include <vespa/document/base/field.h> #include <vespa/document/datatype/arraydatatype.h> #include <vespa/document/datatype/structdatatype.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/longfieldvalue.h> +#include <vespa/eval/eval/value_type.h> #include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/util/exceptions.h> @@ -61,6 +63,47 @@ TEST("require that StructDataType can redeclare identical fields.") { EXPECT_FALSE(s.hasField(field2.getName())); } +class TensorDataTypeFixture { + std::unique_ptr<const TensorDataType> _tensorDataType; +public: + using ValueType = vespalib::eval::ValueType; + TensorDataTypeFixture() + : _tensorDataType() + { + } + + ~TensorDataTypeFixture(); + + void setup(const vespalib::string &spec) + { + _tensorDataType = TensorDataType::fromSpec(spec); + } + + bool isAssignableType(const vespalib::string &spec) const + { + auto assignType = ValueType::from_spec(spec); + return _tensorDataType->isAssignableType(assignType); + } +}; + +TensorDataTypeFixture::~TensorDataTypeFixture() = default; + +TEST_F("require that TensorDataType can check for assignable tensor type", TensorDataTypeFixture) +{ + f.setup("tensor(x[2])"); + EXPECT_TRUE(f.isAssignableType("tensor(x[2])")); + EXPECT_FALSE(f.isAssignableType("tensor(x[3])")); + EXPECT_FALSE(f.isAssignableType("tensor(y[2])")); + EXPECT_FALSE(f.isAssignableType("tensor(x[])")); + EXPECT_FALSE(f.isAssignableType("tensor(x{})")); + f.setup("tensor(x[])"); + EXPECT_TRUE(f.isAssignableType("tensor(x[2])")); + EXPECT_TRUE(f.isAssignableType("tensor(x[3])")); + EXPECT_FALSE(f.isAssignableType("tensor(y[2])")); + EXPECT_FALSE(f.isAssignableType("tensor(x[])")); + EXPECT_FALSE(f.isAssignableType("tensor(x{})")); +} + } // namespace TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/document/src/vespa/document/datatype/tensor_data_type.cpp b/document/src/vespa/document/datatype/tensor_data_type.cpp index d3d747c045f..8aad39c68b7 100644 --- a/document/src/vespa/document/datatype/tensor_data_type.cpp +++ b/document/src/vespa/document/datatype/tensor_data_type.cpp @@ -49,4 +49,26 @@ TensorDataType::fromSpec(const vespalib::string &spec) return std::make_unique<const TensorDataType>(ValueType::from_spec(spec)); } +bool +TensorDataType::isAssignableType(const vespalib::eval::ValueType &rhs) const +{ + const auto &dimensions = _tensorType.dimensions(); + const auto &rhsDimensions = rhs.dimensions(); + if (!rhs.is_tensor() || dimensions.size() != rhsDimensions.size()) { + return false; + } + for (size_t i = 0; i < dimensions.size(); ++i) { + const auto &dim = dimensions[i]; + const auto &rhsDim = rhsDimensions[i]; + if ((dim.name != rhsDim.name) || + (dim.is_indexed() != rhsDim.is_indexed()) || + (rhsDim.is_indexed() && !rhsDim.is_bound()) || + (dim.is_bound() && (dim.size != rhsDim.size))) { + return false; + } + } + return true; + +} + } // document diff --git a/document/src/vespa/document/datatype/tensor_data_type.h b/document/src/vespa/document/datatype/tensor_data_type.h index acb243e96ff..43f14c3c3ea 100644 --- a/document/src/vespa/document/datatype/tensor_data_type.h +++ b/document/src/vespa/document/datatype/tensor_data_type.h @@ -24,6 +24,7 @@ public: DECLARE_IDENTIFIABLE_ABSTRACT(TensorDataType); const vespalib::eval::ValueType &getTensorType() const { return _tensorType; } + bool isAssignableType(const vespalib::eval::ValueType &tensorType) const; }; } |