summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@broadpark.no>2019-02-15 14:58:33 +0100
committerTor Egge <Tor.Egge@broadpark.no>2019-02-15 15:29:04 +0100
commit8cb47edf86487817d3391ecfa63bf6d0707b3c55 (patch)
tree076b9a2dc197a560d194297e3af589ca78a35906 /document
parent5c7396d271bdd02951854c4f2b7b14de8deeb5ad (diff)
Add check for tensor type being compatible with a tensor field value
using the current tensor data type.
Diffstat (limited to 'document')
-rw-r--r--document/src/tests/datatype/datatype_test.cpp43
-rw-r--r--document/src/vespa/document/datatype/tensor_data_type.cpp22
-rw-r--r--document/src/vespa/document/datatype/tensor_data_type.h1
3 files changed, 66 insertions, 0 deletions
diff --git a/document/src/tests/datatype/datatype_test.cpp b/document/src/tests/datatype/datatype_test.cpp
index 61d44fcfd5e..ef41ee770a2 100644
--- a/document/src/tests/datatype/datatype_test.cpp
+++ b/document/src/tests/datatype/datatype_test.cpp
@@ -4,7 +4,9 @@
#include <vespa/document/base/field.h>
#include <vespa/document/datatype/arraydatatype.h>
#include <vespa/document/datatype/structdatatype.h>
+#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/document/fieldvalue/longfieldvalue.h>
+#include <vespa/eval/eval/value_type.h>
#include <vespa/vespalib/testkit/testapp.h>
#include <vespa/vespalib/util/exceptions.h>
@@ -61,6 +63,47 @@ TEST("require that StructDataType can redeclare identical fields.") {
EXPECT_FALSE(s.hasField(field2.getName()));
}
+class TensorDataTypeFixture {
+ std::unique_ptr<const TensorDataType> _tensorDataType;
+public:
+ using ValueType = vespalib::eval::ValueType;
+ TensorDataTypeFixture()
+ : _tensorDataType()
+ {
+ }
+
+ ~TensorDataTypeFixture();
+
+ void setup(const vespalib::string &spec)
+ {
+ _tensorDataType = TensorDataType::fromSpec(spec);
+ }
+
+ bool isAssignableType(const vespalib::string &spec) const
+ {
+ auto assignType = ValueType::from_spec(spec);
+ return _tensorDataType->isAssignableType(assignType);
+ }
+};
+
+TensorDataTypeFixture::~TensorDataTypeFixture() = default;
+
+TEST_F("require that TensorDataType can check for assignable tensor type", TensorDataTypeFixture)
+{
+ f.setup("tensor(x[2])");
+ EXPECT_TRUE(f.isAssignableType("tensor(x[2])"));
+ EXPECT_FALSE(f.isAssignableType("tensor(x[3])"));
+ EXPECT_FALSE(f.isAssignableType("tensor(y[2])"));
+ EXPECT_FALSE(f.isAssignableType("tensor(x[])"));
+ EXPECT_FALSE(f.isAssignableType("tensor(x{})"));
+ f.setup("tensor(x[])");
+ EXPECT_TRUE(f.isAssignableType("tensor(x[2])"));
+ EXPECT_TRUE(f.isAssignableType("tensor(x[3])"));
+ EXPECT_FALSE(f.isAssignableType("tensor(y[2])"));
+ EXPECT_FALSE(f.isAssignableType("tensor(x[])"));
+ EXPECT_FALSE(f.isAssignableType("tensor(x{})"));
+}
+
} // namespace
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/document/src/vespa/document/datatype/tensor_data_type.cpp b/document/src/vespa/document/datatype/tensor_data_type.cpp
index d3d747c045f..8aad39c68b7 100644
--- a/document/src/vespa/document/datatype/tensor_data_type.cpp
+++ b/document/src/vespa/document/datatype/tensor_data_type.cpp
@@ -49,4 +49,26 @@ TensorDataType::fromSpec(const vespalib::string &spec)
return std::make_unique<const TensorDataType>(ValueType::from_spec(spec));
}
+bool
+TensorDataType::isAssignableType(const vespalib::eval::ValueType &rhs) const
+{
+ const auto &dimensions = _tensorType.dimensions();
+ const auto &rhsDimensions = rhs.dimensions();
+ if (!rhs.is_tensor() || dimensions.size() != rhsDimensions.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < dimensions.size(); ++i) {
+ const auto &dim = dimensions[i];
+ const auto &rhsDim = rhsDimensions[i];
+ if ((dim.name != rhsDim.name) ||
+ (dim.is_indexed() != rhsDim.is_indexed()) ||
+ (rhsDim.is_indexed() && !rhsDim.is_bound()) ||
+ (dim.is_bound() && (dim.size != rhsDim.size))) {
+ return false;
+ }
+ }
+ return true;
+
+}
+
} // document
diff --git a/document/src/vespa/document/datatype/tensor_data_type.h b/document/src/vespa/document/datatype/tensor_data_type.h
index acb243e96ff..43f14c3c3ea 100644
--- a/document/src/vespa/document/datatype/tensor_data_type.h
+++ b/document/src/vespa/document/datatype/tensor_data_type.h
@@ -24,6 +24,7 @@ public:
DECLARE_IDENTIFIABLE_ABSTRACT(TensorDataType);
const vespalib::eval::ValueType &getTensorType() const { return _tensorType; }
+ bool isAssignableType(const vespalib::eval::ValueType &tensorType) const;
};
}