summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
Diffstat (limited to 'document')
-rw-r--r--document/src/tests/datatype/datatype_test.cpp9
-rw-r--r--document/src/vespa/document/datatype/datatype.h12
-rw-r--r--document/src/vespa/document/datatype/tensor_data_type.cpp9
-rw-r--r--document/src/vespa/document/datatype/tensor_data_type.h2
4 files changed, 27 insertions, 5 deletions
diff --git a/document/src/tests/datatype/datatype_test.cpp b/document/src/tests/datatype/datatype_test.cpp
index 84c72127735..2fe4a8425db 100644
--- a/document/src/tests/datatype/datatype_test.cpp
+++ b/document/src/tests/datatype/datatype_test.cpp
@@ -81,6 +81,15 @@ TEST_F("require that TensorDataType can check for assignable tensor type", Tenso
EXPECT_FALSE(f.isAssignableType("tensor(x{})"));
}
+TEST("TensorDataType implements equals() that takes underlying tensor type into consideration")
+{
+ auto a = TensorDataType::fromSpec("tensor<float>(x[4])");
+ auto b = TensorDataType::fromSpec("tensor<bfloat16>(x[4])");
+ EXPECT_EQUAL(*a, *a);
+ EXPECT_NOT_EQUAL(*a, *b);
+ EXPECT_NOT_EQUAL(*b, *a);
+}
+
} // namespace
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/document/src/vespa/document/datatype/datatype.h b/document/src/vespa/document/datatype/datatype.h
index 1cd3f898be1..8deca98eb74 100644
--- a/document/src/vespa/document/datatype/datatype.h
+++ b/document/src/vespa/document/datatype/datatype.h
@@ -18,14 +18,15 @@ class FieldValue;
class Field;
class FieldPath;
-class NumericDataType;
-class PrimitiveDataType;
-class DocumentType;
-class WeightedSetDataType;
-class CollectionDataType;
class ArrayDataType;
+class CollectionDataType;
+class DocumentType;
class MapDataType;
+class NumericDataType;
+class PrimitiveDataType;
class ReferenceDataType;
+class TensorDataType;
+class WeightedSetDataType;
class DataType : public Printable
{
@@ -120,6 +121,7 @@ public:
virtual const CollectionDataType * cast_collection() const noexcept { return nullptr; }
virtual const MapDataType * cast_map() const noexcept { return nullptr; }
virtual const ReferenceDataType * cast_reference() const noexcept { return nullptr; }
+ virtual const TensorDataType* cast_tensor() const noexcept { return nullptr; }
bool isMap() const { return cast_map() != nullptr; }
/**
diff --git a/document/src/vespa/document/datatype/tensor_data_type.cpp b/document/src/vespa/document/datatype/tensor_data_type.cpp
index 820a7cf3dcd..99cda9df421 100644
--- a/document/src/vespa/document/datatype/tensor_data_type.cpp
+++ b/document/src/vespa/document/datatype/tensor_data_type.cpp
@@ -18,6 +18,15 @@ TensorDataType::TensorDataType(ValueType tensorType)
TensorDataType::TensorDataType(const TensorDataType &) = default;
TensorDataType::~TensorDataType() = default;
+bool
+TensorDataType::equals(const DataType& other) const noexcept
+{
+ if (!DataType::equals(other)) {
+ return false;
+ }
+ return _tensorType == other.cast_tensor()->_tensorType;
+}
+
FieldValue::UP
TensorDataType::createFieldValue() const
{
diff --git a/document/src/vespa/document/datatype/tensor_data_type.h b/document/src/vespa/document/datatype/tensor_data_type.h
index b2f313f0778..f0afb976f14 100644
--- a/document/src/vespa/document/datatype/tensor_data_type.h
+++ b/document/src/vespa/document/datatype/tensor_data_type.h
@@ -18,6 +18,8 @@ public:
~TensorDataType();
bool isTensor() const noexcept override { return true; }
+ virtual const TensorDataType* cast_tensor() const noexcept override { return this; }
+ bool equals(const DataType& other) const noexcept override;
std::unique_ptr<FieldValue> createFieldValue() const override;
void print(std::ostream&, bool verbose, const std::string& indent) const override;
static std::unique_ptr<const TensorDataType> fromSpec(const vespalib::string &spec);