diff options
Diffstat (limited to 'document')
4 files changed, 27 insertions, 5 deletions
diff --git a/document/src/tests/datatype/datatype_test.cpp b/document/src/tests/datatype/datatype_test.cpp index 84c72127735..2fe4a8425db 100644 --- a/document/src/tests/datatype/datatype_test.cpp +++ b/document/src/tests/datatype/datatype_test.cpp @@ -81,6 +81,15 @@ TEST_F("require that TensorDataType can check for assignable tensor type", Tenso EXPECT_FALSE(f.isAssignableType("tensor(x{})")); } +TEST("TensorDataType implements equals() that takes underlying tensor type into consideration") +{ + auto a = TensorDataType::fromSpec("tensor<float>(x[4])"); + auto b = TensorDataType::fromSpec("tensor<bfloat16>(x[4])"); + EXPECT_EQUAL(*a, *a); + EXPECT_NOT_EQUAL(*a, *b); + EXPECT_NOT_EQUAL(*b, *a); +} + } // namespace TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/document/src/vespa/document/datatype/datatype.h b/document/src/vespa/document/datatype/datatype.h index 1cd3f898be1..8deca98eb74 100644 --- a/document/src/vespa/document/datatype/datatype.h +++ b/document/src/vespa/document/datatype/datatype.h @@ -18,14 +18,15 @@ class FieldValue; class Field; class FieldPath; -class NumericDataType; -class PrimitiveDataType; -class DocumentType; -class WeightedSetDataType; -class CollectionDataType; class ArrayDataType; +class CollectionDataType; +class DocumentType; class MapDataType; +class NumericDataType; +class PrimitiveDataType; class ReferenceDataType; +class TensorDataType; +class WeightedSetDataType; class DataType : public Printable { @@ -120,6 +121,7 @@ public: virtual const CollectionDataType * cast_collection() const noexcept { return nullptr; } virtual const MapDataType * cast_map() const noexcept { return nullptr; } virtual const ReferenceDataType * cast_reference() const noexcept { return nullptr; } + virtual const TensorDataType* cast_tensor() const noexcept { return nullptr; } bool isMap() const { return cast_map() != nullptr; } /** diff --git a/document/src/vespa/document/datatype/tensor_data_type.cpp b/document/src/vespa/document/datatype/tensor_data_type.cpp index 820a7cf3dcd..99cda9df421 100644 --- a/document/src/vespa/document/datatype/tensor_data_type.cpp +++ b/document/src/vespa/document/datatype/tensor_data_type.cpp @@ -18,6 +18,15 @@ TensorDataType::TensorDataType(ValueType tensorType) TensorDataType::TensorDataType(const TensorDataType &) = default; TensorDataType::~TensorDataType() = default; +bool +TensorDataType::equals(const DataType& other) const noexcept +{ + if (!DataType::equals(other)) { + return false; + } + return _tensorType == other.cast_tensor()->_tensorType; +} + FieldValue::UP TensorDataType::createFieldValue() const { diff --git a/document/src/vespa/document/datatype/tensor_data_type.h b/document/src/vespa/document/datatype/tensor_data_type.h index b2f313f0778..f0afb976f14 100644 --- a/document/src/vespa/document/datatype/tensor_data_type.h +++ b/document/src/vespa/document/datatype/tensor_data_type.h @@ -18,6 +18,8 @@ public: ~TensorDataType(); bool isTensor() const noexcept override { return true; } + virtual const TensorDataType* cast_tensor() const noexcept override { return this; } + bool equals(const DataType& other) const noexcept override; std::unique_ptr<FieldValue> createFieldValue() const override; void print(std::ostream&, bool verbose, const std::string& indent) const override; static std::unique_ptr<const TensorDataType> fromSpec(const vespalib::string &spec); |