diff options
author | Geir Storli <geirst@yahooinc.com> | 2022-08-25 15:35:24 +0000 |
---|---|---|
committer | Geir Storli <geirst@yahooinc.com> | 2022-08-25 15:35:24 +0000 |
commit | ab30fd870df3f8b118185e3b52bf520458da0c33 (patch) | |
tree | 6d3b8133b3407db12d06e3226240fe9d2fc00672 | |
parent | 5417df91c88f3222f2a6bff95db7f1a349114d92 (diff) |
Implement equals() for TensorDataType.
This fixes a bug where a tensor attribute is kept in the search backend
after its tensor type has changed and the attribute aspect has been removed from the schema.
The equals() function is used as part of DocumentTypeInspector::hasUnchangedField().
4 files changed, 27 insertions, 5 deletions
diff --git a/document/src/tests/datatype/datatype_test.cpp b/document/src/tests/datatype/datatype_test.cpp index 84c72127735..2fe4a8425db 100644 --- a/document/src/tests/datatype/datatype_test.cpp +++ b/document/src/tests/datatype/datatype_test.cpp @@ -81,6 +81,15 @@ TEST_F("require that TensorDataType can check for assignable tensor type", Tenso EXPECT_FALSE(f.isAssignableType("tensor(x{})")); } +TEST("TensorDataType implements equals() that takes underlying tensor type into consideration") +{ + auto a = TensorDataType::fromSpec("tensor<float>(x[4])"); + auto b = TensorDataType::fromSpec("tensor<bfloat16>(x[4])"); + EXPECT_EQUAL(*a, *a); + EXPECT_NOT_EQUAL(*a, *b); + EXPECT_NOT_EQUAL(*b, *a); +} + } // namespace TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/document/src/vespa/document/datatype/datatype.h b/document/src/vespa/document/datatype/datatype.h index 1cd3f898be1..8deca98eb74 100644 --- a/document/src/vespa/document/datatype/datatype.h +++ b/document/src/vespa/document/datatype/datatype.h @@ -18,14 +18,15 @@ class FieldValue; class Field; class FieldPath; -class NumericDataType; -class PrimitiveDataType; -class DocumentType; -class WeightedSetDataType; -class CollectionDataType; class ArrayDataType; +class CollectionDataType; +class DocumentType; class MapDataType; +class NumericDataType; +class PrimitiveDataType; class ReferenceDataType; +class TensorDataType; +class WeightedSetDataType; class DataType : public Printable { @@ -120,6 +121,7 @@ public: virtual const CollectionDataType * cast_collection() const noexcept { return nullptr; } virtual const MapDataType * cast_map() const noexcept { return nullptr; } virtual const ReferenceDataType * cast_reference() const noexcept { return nullptr; } + virtual const TensorDataType* cast_tensor() const noexcept { return nullptr; } bool isMap() const { return cast_map() != nullptr; } /** diff --git a/document/src/vespa/document/datatype/tensor_data_type.cpp b/document/src/vespa/document/datatype/tensor_data_type.cpp index 820a7cf3dcd..99cda9df421 100644 --- a/document/src/vespa/document/datatype/tensor_data_type.cpp +++ b/document/src/vespa/document/datatype/tensor_data_type.cpp @@ -18,6 +18,15 @@ TensorDataType::TensorDataType(ValueType tensorType) TensorDataType::TensorDataType(const TensorDataType &) = default; TensorDataType::~TensorDataType() = default; +bool +TensorDataType::equals(const DataType& other) const noexcept +{ + if (!DataType::equals(other)) { + return false; + } + return _tensorType == other.cast_tensor()->_tensorType; +} + FieldValue::UP TensorDataType::createFieldValue() const { diff --git a/document/src/vespa/document/datatype/tensor_data_type.h b/document/src/vespa/document/datatype/tensor_data_type.h index b2f313f0778..f0afb976f14 100644 --- a/document/src/vespa/document/datatype/tensor_data_type.h +++ b/document/src/vespa/document/datatype/tensor_data_type.h @@ -18,6 +18,8 @@ public: ~TensorDataType(); bool isTensor() const noexcept override { return true; } + virtual const TensorDataType* cast_tensor() const noexcept override { return this; } + bool equals(const DataType& other) const noexcept override; std::unique_ptr<FieldValue> createFieldValue() const override; void print(std::ostream&, bool verbose, const std::string& indent) const override; static std::unique_ptr<const TensorDataType> fromSpec(const vespalib::string &spec); |