summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2022-08-25 15:35:24 +0000
committerGeir Storli <geirst@yahooinc.com>2022-08-25 15:35:24 +0000
commitab30fd870df3f8b118185e3b52bf520458da0c33 (patch)
tree6d3b8133b3407db12d06e3226240fe9d2fc00672 /document
parent5417df91c88f3222f2a6bff95db7f1a349114d92 (diff)
Implement equals() for TensorDataType.
This fixes a bug where a tensor attribute is kept in the search backend after its tensor type has changed and the attribute aspect has been removed from the schema. The equals() function is used as part of DocumentTypeInspector::hasUnchangedField().
Diffstat (limited to 'document')
-rw-r--r--document/src/tests/datatype/datatype_test.cpp9
-rw-r--r--document/src/vespa/document/datatype/datatype.h12
-rw-r--r--document/src/vespa/document/datatype/tensor_data_type.cpp9
-rw-r--r--document/src/vespa/document/datatype/tensor_data_type.h2
4 files changed, 27 insertions, 5 deletions
diff --git a/document/src/tests/datatype/datatype_test.cpp b/document/src/tests/datatype/datatype_test.cpp
index 84c72127735..2fe4a8425db 100644
--- a/document/src/tests/datatype/datatype_test.cpp
+++ b/document/src/tests/datatype/datatype_test.cpp
@@ -81,6 +81,15 @@ TEST_F("require that TensorDataType can check for assignable tensor type", Tenso
EXPECT_FALSE(f.isAssignableType("tensor(x{})"));
}
+TEST("TensorDataType implements equals() that takes underlying tensor type into consideration")
+{
+ auto a = TensorDataType::fromSpec("tensor<float>(x[4])");
+ auto b = TensorDataType::fromSpec("tensor<bfloat16>(x[4])");
+ EXPECT_EQUAL(*a, *a);
+ EXPECT_NOT_EQUAL(*a, *b);
+ EXPECT_NOT_EQUAL(*b, *a);
+}
+
} // namespace
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/document/src/vespa/document/datatype/datatype.h b/document/src/vespa/document/datatype/datatype.h
index 1cd3f898be1..8deca98eb74 100644
--- a/document/src/vespa/document/datatype/datatype.h
+++ b/document/src/vespa/document/datatype/datatype.h
@@ -18,14 +18,15 @@ class FieldValue;
class Field;
class FieldPath;
-class NumericDataType;
-class PrimitiveDataType;
-class DocumentType;
-class WeightedSetDataType;
-class CollectionDataType;
class ArrayDataType;
+class CollectionDataType;
+class DocumentType;
class MapDataType;
+class NumericDataType;
+class PrimitiveDataType;
class ReferenceDataType;
+class TensorDataType;
+class WeightedSetDataType;
class DataType : public Printable
{
@@ -120,6 +121,7 @@ public:
virtual const CollectionDataType * cast_collection() const noexcept { return nullptr; }
virtual const MapDataType * cast_map() const noexcept { return nullptr; }
virtual const ReferenceDataType * cast_reference() const noexcept { return nullptr; }
+ virtual const TensorDataType* cast_tensor() const noexcept { return nullptr; }
bool isMap() const { return cast_map() != nullptr; }
/**
diff --git a/document/src/vespa/document/datatype/tensor_data_type.cpp b/document/src/vespa/document/datatype/tensor_data_type.cpp
index 820a7cf3dcd..99cda9df421 100644
--- a/document/src/vespa/document/datatype/tensor_data_type.cpp
+++ b/document/src/vespa/document/datatype/tensor_data_type.cpp
@@ -18,6 +18,15 @@ TensorDataType::TensorDataType(ValueType tensorType)
TensorDataType::TensorDataType(const TensorDataType &) = default;
TensorDataType::~TensorDataType() = default;
+bool
+TensorDataType::equals(const DataType& other) const noexcept
+{
+ if (!DataType::equals(other)) {
+ return false;
+ }
+ return _tensorType == other.cast_tensor()->_tensorType;
+}
+
FieldValue::UP
TensorDataType::createFieldValue() const
{
diff --git a/document/src/vespa/document/datatype/tensor_data_type.h b/document/src/vespa/document/datatype/tensor_data_type.h
index b2f313f0778..f0afb976f14 100644
--- a/document/src/vespa/document/datatype/tensor_data_type.h
+++ b/document/src/vespa/document/datatype/tensor_data_type.h
@@ -18,6 +18,8 @@ public:
~TensorDataType();
bool isTensor() const noexcept override { return true; }
+ virtual const TensorDataType* cast_tensor() const noexcept override { return this; }
+ bool equals(const DataType& other) const noexcept override;
std::unique_ptr<FieldValue> createFieldValue() const override;
void print(std::ostream&, bool verbose, const std::string& indent) const override;
static std::unique_ptr<const TensorDataType> fromSpec(const vespalib::string &spec);